Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow/workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ load("@local_xla//third_party/nvshmem:workspace.bzl", nvshmem = "repo")
load("@local_xla//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
load("@local_xla//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
load("@local_xla//third_party/robin_map:workspace.bzl", robin_map = "repo")
load("@local_xla//third_party/openblas:workspace.bzl", openblas = "repo")
load("@rules_jvm_external//:defs.bzl", "maven_install")
load("@tf_runtime//:dependencies.bzl", "tfrt_dependencies")
load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl", "def_file_filter_configure")
Expand Down Expand Up @@ -100,6 +101,7 @@ def _initialize_third_party():
tensorrt()
nvshmem()
triton()
openblas()

# copybara: tsl vendor

Expand Down
Empty file.
17 changes: 17 additions & 0 deletions third_party/xla/third_party/openblas/openblas.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Builds a static, CBLAS-only OpenBLAS archive from the vendored sources.
#
# Fixes vs. the original:
#   * TARGET=ARMV8 ARCH=arm64 were hard-coded, which broke the build on every
#     non-ARM host. OpenBLAS's make autodetects the host CPU when TARGET is
#     not given, which is correct for all native builds.
#   * The glob exclude only matched root-level "*.a"; "**/*.a" ensures a
#     stale archive anywhere in the tree is never picked up as an input.
# NOTE(review): the cp relies on OpenBLAS's target-suffixed archive name
# (libopenblas_<target>.a) — confirm the pattern matches on all hosts.
genrule(
    name = "build_openblas",
    srcs = glob(["**"], exclude = ["**/*.a"]),
    outs = ["libopenblas.a"],
    cmd = """
        cd $$(dirname $(location //:README.md)) && \
        make NO_SHARED=1 ONLY_CBLAS=1 && \
        cd - && \
        cp $$(dirname $(location //:README.md))/libopenblas_*.a $@
    """,
)

# Static import consumed as @openblas//:openblas by kernel_selector.
cc_import(
    name = "openblas",
    static_library = "libopenblas.a",
    visibility = ["//visibility:public"],
)
10 changes: 10 additions & 0 deletions third_party/xla/third_party/openblas/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Registers the @openblas repository, pinned to a specific OpenBLAS commit."""
    commit = "8795fc7985635de1ecf674b87e2008a15097ffab"
    tf_http_archive(
        name = "openblas",
        build_file = "//third_party/openblas:openblas.BUILD",
        sha256 = "f5ff825b3a82417d47c2ba97606ce8a5d868f863e555025f5d4112e6dfd62e2f",
        strip_prefix = "OpenBLAS-" + commit,
        urls = tf_mirror_urls("https://github.com/OpenMathLib/OpenBLAS/archive/" + commit + ".tar.gz"),
    )
2 changes: 2 additions & 0 deletions third_party/xla/workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ load("//third_party/shardy:workspace.bzl", shardy = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
load("//third_party/triton:workspace.bzl", triton = "repo")
load("//third_party/uv:workspace.bzl", uv = "repo")
load("//third_party/openblas:workspace.bzl", openblas = "repo")

def _initialize_third_party():
""" Load third party repositories. See above load() statements. """
Expand All @@ -31,6 +32,7 @@ def _initialize_third_party():
stablehlo()
triton()
uv()
openblas()

# Define all external repositories required by TensorFlow
def _tf_repositories():
Expand Down
12 changes: 12 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_use_fusion_emitters(true);
opts.set_xla_cpu_use_thunk_runtime(true);
opts.set_xla_cpu_use_xnnpack(false);
opts.set_xla_cpu_enable_xnnpack(false); // For softmax
opts.set_xla_cpu_use_kernel_selector(false);
opts.set_xla_cpu_experimental_xnn_graph_fusion_mode(
DebugOptions::XNN_GRAPH_FUSION_MODE_DISABLED);
opts.set_xla_cpu_parallel_codegen_split_count(32);
Expand Down Expand Up @@ -994,6 +996,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(&DebugOptions::set_xla_cpu_use_xnnpack),
debug_options->xla_cpu_use_xnnpack(),
"Use XNNPACK for supported operations."));
flag_list->push_back(tsl::Flag(
"xla_cpu_enable_xnnpack",
bool_setter_for(&DebugOptions::set_xla_cpu_enable_xnnpack),
debug_options->xla_cpu_enable_xnnpack(),
"Enable XNNPACK ops rewriter."));
flag_list->push_back(tsl::Flag(
"xla_cpu_use_kernel_selector",
bool_setter_for(&DebugOptions::set_xla_cpu_use_kernel_selector),
debug_options->xla_cpu_use_kernel_selector() ,
"Replace dot with custom call to libraries."));
flag_list->push_back(tsl::Flag(
"xla_cpu_experimental_xnn_graph_fusion_mode",
setter_for_xla_cpu_experimental_xnn_graph_fusion_mode,
Expand Down
113 changes: 109 additions & 4 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ filegroup(
"runtime_single_threaded_matmul_s32.cc",
"runtime_single_threaded_matmul_u8.cc",
"runtime_topk.cc",
"xnnpack_ops.cc",
# Multi-threaded support.
"runtime_conv2d.cc",
"runtime_conv3d.cc",
Expand All @@ -88,6 +89,7 @@ filegroup(
"runtime_matmul_f64.cc",
"runtime_matmul_s32.cc",
"runtime_fork_join.cc",
"kernel_selector.cc",
"//xla/backends/cpu/runtime:runtime_srcs",
#"runtime_handle_ffi_call.cc", # TODO(b/338344732): Add "runtime_handle_ffi_call.cc".
],
Expand All @@ -109,13 +111,15 @@ filegroup(
"runtime_single_threaded_fft.h",
"runtime_single_threaded_matmul.h",
"runtime_topk.h",
"xnnpack_ops.h",
# Multi-threaded support.
"runtime_conv2d.h",
"runtime_conv3d.h",
"runtime_fft.h",
"runtime_fork_join.h",
"runtime_lightweight_check.h",
"runtime_matmul.h",
"kernel_selector.h",
"//xla/backends/cpu/runtime:runtime_hdrs",
#"runtime_handle_ffi_call.h", # TODO(b/338344732): Add "runtime_handle_ffi_call.h"
],
Expand Down Expand Up @@ -193,7 +197,11 @@ cc_library(
name = "cpu_compiler_pure",
srcs = ["cpu_compiler.cc"],
hdrs = ["cpu_compiler.h"],
copts = tsl_copts(),
copts = tsl_copts() + select({
":enable_blas_mlir": ["-DENABLE_BLAS_MLIR"],
":disable_blas_mlir": [],
"//conditions:default": [],
}),
deps = [
":buffer_info_util",
":conv_canonicalization",
Expand All @@ -218,6 +226,8 @@ cc_library(
":small_while_loop_hoisting_pass",
":thunk_emitter",
":xla_framework",
":xnnpack_ops_rewriter",
":kernel_selector_ops_rewriter",
"//xla:cpu_function_runtime",
"//xla:debug_options_flags",
"//xla:literal",
Expand Down Expand Up @@ -417,7 +427,21 @@ cc_library(
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
]) + if_llvm_x86_available([
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
]),
]) + select({
":enable_blas_mlir": [":libmlir"],
":disable_blas_mlir": [],
"//conditions:default": [],
}),
)

# Toggle for the MLIR-generated BLAS kernels: select one of these settings
# with --define=ENABLE_BLAS_MLIR=true (or =false) on the bazel command line.
# When neither define is passed, selects below fall through to
# "//conditions:default", which behaves like "disabled".
config_setting(
    name = "enable_blas_mlir",
    define_values = {"ENABLE_BLAS_MLIR": "true"},
)

config_setting(
    name = "disable_blas_mlir",
    define_values = {"ENABLE_BLAS_MLIR": "false"},
)

cc_library(
Expand Down Expand Up @@ -592,7 +616,11 @@ cc_library(
"windows_compatibility.h",
],
hdrs = ["runtime_symbol_generator.h"],
copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]) + tsl_copts(),
copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]) + tsl_copts() + select({
":enable_blas_mlir": ["-DENABLE_BLAS_MLIR"],
":disable_blas_mlir": [],
"//conditions:default": [],
}),
deps = [
":cpu_runtime",
":onednn_convolution",
Expand All @@ -617,6 +645,8 @@ cc_library(
":runtime_single_threaded_fft",
":runtime_single_threaded_matmul",
":runtime_topk",
":xnnpack_ops",
":kernel_selector",
"//xla/service:custom_call_target_registry",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/strings:string_view",
Expand Down Expand Up @@ -1102,7 +1132,11 @@ cc_library(
"cpu_runtime.h",
"xfeed_manager.h",
],
copts = runtime_copts(),
copts = runtime_copts() + select({
":enable_blas_mlir": ["-DENABLE_BLAS_MLIR"],
":disable_blas_mlir": [],
"//conditions:default": [],
}),
deps = [
":cpu_executable_run_options",
"//xla:executable_run_options",
Expand Down Expand Up @@ -2187,3 +2221,74 @@ xla_cc_test(
"@local_tsl//tsl/platform:test",
],
)

# HLO pass that rewrites supported patterns (softmax, per the compiler-side
# comment) into custom calls with XNNPACK targets; gated at runtime by
# --xla_cpu_enable_xnnpack.
# NOTE(review): hard-coding "-O3" overrides the user's compilation mode and
# is not portable (e.g. MSVC) — consider tsl_copts() instead.
cc_library(
    name = "xnnpack_ops_rewriter",
    srcs = ["xnnpack_ops_rewriter.cc"],
    hdrs = [
        "xnnpack_ops_rewriter.h",
        "xnnpack_pattern_utils.h",
    ],
    copts = ["-O3"],
    visibility = ["//visibility:public"],
    deps = [
        "//xla/hlo/ir:hlo",
        "//xla:literal_comparison",
        "//xla:literal_util",
        "//xla:status_macros",
        "//xla/hlo/pass:hlo_pass",
        "//xla/service:pattern_matcher",
    ],
)

# Runtime implementations backing the XNNPACK custom-call targets
# (e.g. __xla_cpu_runtime_XnnPackSoftMaxND registered in cpu_runtime).
cc_library(
    name = "xnnpack_ops",
    srcs = ["xnnpack_ops.cc"],
    hdrs = ["xnnpack_ops.h"],
    copts = ["-O3"],  # NOTE(review): overrides build mode; prefer tsl_copts().
    visibility = ["//visibility:public"],
    deps = [
        "@XNNPACK",
        "@com_google_absl//absl/base",
    ],
)

# Runtime kernels behind the KernelSelector custom calls (GEMV/GEMM/batched
# dot paths), backed by OpenBLAS, with optional MLIR-generated kernels
# compiled in when --define=ENABLE_BLAS_MLIR=true selects :enable_blas_mlir.
# NOTE(review): hard-coded "-O3" overrides the user's compilation mode and is
# not portable — consider tsl_copts() instead.
cc_library(
    name = "kernel_selector",
    srcs = ["kernel_selector.cc"],
    hdrs = ["kernel_selector.h"],
    copts = ["-O3"] + select({
        ":enable_blas_mlir": ["-DENABLE_BLAS_MLIR"],
        ":disable_blas_mlir": [],
        "//conditions:default": [],
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":runtime_lightweight_check",
        "//xla:executable_run_options",
        "@eigen_archive//:eigen3",
        "@local_tsl//tsl/platform:blocking_counter",
        "@openblas//:openblas",
    ],
)

# HLO pass that rewrites hlo.dot into KernelSelector custom calls (see
# cpu_compiler.cc); gated at runtime by --xla_cpu_use_kernel_selector.
cc_library(
    name = "kernel_selector_ops_rewriter",
    srcs = ["kernel_selector_ops_rewriter.cc"],
    hdrs = ["kernel_selector_ops_rewriter.h"],
    copts = ["-O3"],  # NOTE(review): overrides build mode; prefer tsl_copts().
    visibility = ["//visibility:public"],
    deps = [
        ":cpu_runtime",
        "//xla/hlo/ir:hlo",
        "//xla:literal_util",
        "//xla/hlo/pass:hlo_pass",
    ],
)

# Prebuilt MLIR-generated BLAS kernel library; linked into :runtime only when
# --define=ENABLE_BLAS_MLIR=true selects :enable_blas_mlir.
cc_import(
    name = "libmlir",
    shared_library = "//xla/service/libs:libblas_mlir.so",
    system_provided = 0,
    visibility = ["//visibility:public"],
)
20 changes: 19 additions & 1 deletion third_party/xla/xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ limitations under the License.
#include "xla/service/cpu/runtime_symbol_generator.h"
#include "xla/service/cpu/small_while_loop_hoisting_pass.h"
#include "xla/service/cpu/thunk_emitter.h"
#include "xla/service/cpu/xnnpack_ops_rewriter.h"
#include "xla/service/cpu/kernel_selector_ops_rewriter.h"
#include "xla/service/cpu_gpu_shape_verifier.h"
#include "xla/service/dump.h"
#include "xla/service/dynamic_dimension_inference.h"
Expand Down Expand Up @@ -591,6 +593,12 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
};
pipeline.AddPass<OperandUpcaster>(upcaster_filter);

// For softmax, rewrite to custom calls with XNNPACK targets.
bool enable_xnnpack =
xla::GetDebugOptionsFromFlags().xla_cpu_enable_xnnpack();
if (enable_xnnpack)
pipeline.AddPass<XnnPackOpsRewriter>();

// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
Expand Down Expand Up @@ -831,6 +839,13 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(

pipeline.AddPass<ReshapeDecomposer>();

bool use_kernel_selector =
xla::GetDebugOptionsFromFlags().xla_cpu_use_kernel_selector();
if (use_kernel_selector) {
// This pass rewrites hlo.dot into custom calls.
pipeline.AddPass<KernelSelectorOpsRewriter>();
}

const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
? module->config().intra_op_parallelism_threads()
Expand Down Expand Up @@ -863,7 +878,10 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
}

// Add a fusion pass now that layout assignment is done.
pipeline.AddPass<CpuInstructionFusion>();
if (getenv("SET_CPU_INS_FUSION_NOT_DUPLICATE") != NULL)
pipeline.AddPass<CpuInstructionFusion>(/*may_duplicate=*/false);
else
pipeline.AddPass<CpuInstructionFusion>(/*may_duplicate=*/true);
if (is_fusion_emitters) {
pipeline.AddPass<FusionWrapper>();
}
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/cpu/cpu_instruction_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cpu {

class CpuInstructionFusion : public InstructionFusion {
public:
CpuInstructionFusion()
: InstructionFusion(CpuInstructionFusion::IsExpensive) {}
CpuInstructionFusion(bool may_duplicate)
: InstructionFusion(CpuInstructionFusion::IsExpensive, may_duplicate) {}
~CpuInstructionFusion() override = default;

using HloPassInterface::Run;
Expand Down
48 changes: 48 additions & 0 deletions third_party/xla/xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,54 @@ extern const char* const kOneDnnMatMulReorderSymbolName =
"__xla_cpu_runtime_OneDnnMatMulReorder";
extern const char* const kHandleFfiCallSymbolName =
"__xla_cpu_runtime_HandleFfiCall";
// Symbol names for the XNNPACK / kernel-selector runtime entry points.
extern const char* const kXnnPackSoftMaxNDSymbolName =
    "__xla_cpu_runtime_XnnPackSoftMaxND";
extern const char* const kArgMax3DParallelSymbolName =
    "__xla_cpu_runtime_ArgMax3DParallel";
extern const char* const kArgMax3DSequentialSymbolName =
    "__xla_cpu_runtime_ArgMax3DSequential";
extern const char* const kKernelSelectorGEMVSymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMV";
extern const char* const kKernelSelectorGEMMSequentialSymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMMSequential";
extern const char* const kKernelSelectorGEMMParallelSymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMMParallel";
extern const char* const kKernelSelectorBatch3DSequentialSymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch3DSequential";
extern const char* const kKernelSelectorBatch3DParallelSymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch3DParallel";
extern const char* const kKernelSelectorBatch4DSequentialSymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch4DSequential";
extern const char* const kKernelSelectorBatch4DParallelSymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch4DParallel";
// MLIR-backed kernels exist only when the libblas_mlir library is linked in
// (see :enable_blas_mlir in the BUILD file). The two previously separate
// #ifdef ENABLE_BLAS_MLIR regions are merged into one for readability.
#ifdef ENABLE_BLAS_MLIR
extern const char* const kKernelSelectorGEMVMLIRSymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMVMLIR";
extern const char* const kKernelSelectorGEMMMLIRSymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMMMLIR";
extern const char* const kKernelSelectorBatch3DMLIRSymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch3DMLIR";
extern const char* const kKernelSelectorBatch4DMLIRSymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch4DMLIR";
#endif  // ENABLE_BLAS_MLIR
// "Empty" variants — presumably no-op fallbacks; confirm against
// kernel_selector.cc.
extern const char* const kKernelSelectorGEMVEmptySymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMVEmpty";
extern const char* const kKernelSelectorGEMMEmptySymbolName =
    "__xla_cpu_runtime_KernelSelectorGEMMEmpty";
extern const char* const kKernelSelectorBatch3DEmptySymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch3DEmpty";
extern const char* const kKernelSelectorBatch4DEmptySymbolName =
    "__xla_cpu_runtime_KernelSelectorBatch4DEmpty";
extern const char* const kArgMax3DEmptySymbolName =
    "__xla_cpu_runtime_ArgMax3DEmpty";
// Operation tags used by the KernelSelector custom call.
extern const char* const kKernelSelectorOperationGEMV = "GEMV";
extern const char* const kKernelSelectorOperationGEMM = "GEMM";
extern const char* const kKernelSelectorOperationBATCH3D = "BATCH3D";
extern const char* const kKernelSelectorOperationBATCH4D = "BATCH4D";
extern const char* const kKernelSelectorOperationARGMAX = "ARGMAX";
extern const char* const kCustomCallKernelSelector = "KernelSelector";

namespace {

Expand Down
Loading