diff --git a/.github/workflows/ci-ut.yml b/.github/workflows/ci-ut.yml index d22b457404..c780ad4891 100644 --- a/.github/workflows/ci-ut.yml +++ b/.github/workflows/ci-ut.yml @@ -82,7 +82,6 @@ jobs: ${{ matrix.mode.config }} \ --curses=no \ --color=yes \ - --strategy=TestRunner=local \ -- \ @jax//tests:gpu_tests \ @jax//tests:backend_independent_tests \ diff --git a/jax_rocm_plugin/build/rocm/jax.bazelrc b/jax_rocm_plugin/build/rocm/jax.bazelrc index 52e4abbdfd..e519c10669 100644 --- a/jax_rocm_plugin/build/rocm/jax.bazelrc +++ b/jax_rocm_plugin/build/rocm/jax.bazelrc @@ -14,9 +14,15 @@ build:rocm --repo_env=HERMETIC_PYTHON_VERSION=3.12 build:rocm_mgpu --test_tag_filters=jax_test_gpu,multiaccelerator,-config-cuda-only,-manual build:rocm_mgpu --build_tag_filters=jax_test_gpu,multiaccelerator,-config-cuda-only,-manual +build:rocm_mgpu --host_platform="//platform/linux:tf_linux_multigpu" +build:rocm_mgpu --extra_execution_platforms="//platform/linux:tf_linux_multigpu" +build:rocm_mgpu --platforms="//platform/linux:tf_linux_multigpu" build:rocm_sgpu --test_tag_filters=jax_test_gpu,-multiaccelerator,-config-cuda-only,-manual build:rocm_sgpu --build_tag_filters=jax_test_gpu,-multiaccelerator,-config-cuda-only,-manual +build:rocm_sgpu --host_platform="//platform/linux:tf_linux_gpu" +build:rocm_sgpu --extra_execution_platforms="//platform/linux:tf_linux_gpu" +build:rocm_sgpu --platforms="//platform/linux:tf_linux_gpu" # for @xla//build_tools/rocm:parallel_gpu_execute build:rocm --legacy_external_runfiles=true diff --git a/jax_rocm_plugin/platform/linux/BUILD b/jax_rocm_plugin/platform/linux/BUILD index d6418fce23..7f94168f5c 100644 --- a/jax_rocm_plugin/platform/linux/BUILD +++ b/jax_rocm_plugin/platform/linux/BUILD @@ -43,3 +43,18 @@ platform( "Pool": "linux_x64_gpu", }, ) + +platform( + name = "tf_linux_multigpu", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + "@bazel_tools//tools/cpp:clang", + ], + # note this image shall match the container one executes the build command! + exec_properties = { + "container-image": "docker://rocm/tensorflow-build@sha256:7fcfbd36b7ac8f6b0805b37c4248e929e31cf5ee3af766c8409dd70d5ab65faa", + "OSFamily": "Linux", + "Pool": "linux_x64_multigpu", + }, +)