From 92148e0b108c3547ca73c2b08b2149b6bf349e75 Mon Sep 17 00:00:00 2001 From: "Jinyun (Joey) Ye" Date: Wed, 3 Dec 2025 20:43:14 +0800 Subject: [PATCH 01/16] Enable serving build --- tensorflow/tools/toolchains/python/python_repo.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/toolchains/python/python_repo.bzl b/tensorflow/tools/toolchains/python/python_repo.bzl index 2af9b29d7af20b..158fd77bd8038b 100644 --- a/tensorflow/tools/toolchains/python/python_repo.bzl +++ b/tensorflow/tools/toolchains/python/python_repo.bzl @@ -21,6 +21,7 @@ TF_PYTHON_VERSION = "{}" HERMETIC_PYTHON_VERSION = "{}" WHEEL_NAME = "{}" WHEEL_COLLAB = "{}" +USE_PYWRAP_RULES = "False" """ def _python_repository_impl(repository_ctx): From 66aca809f27c0d0ec163b2dcb90dc67e4422cf71 Mon Sep 17 00:00:00 2001 From: "Jinyun (Joey) Ye" Date: Fri, 5 Dec 2025 01:00:24 +0800 Subject: [PATCH 02/16] [Huawei] Add debug option tf_xla_annotate_cluster_id Adding --tf_xla_annotate_cluster_id to allow operator name starting with .cluster.id to influcence clustering decisions --- tensorflow/compiler/jit/flags.cc | 7 +++ tensorflow/compiler/jit/flags.h | 3 + .../compiler/jit/mark_for_compilation_pass.cc | 62 +++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 10756ddf9de7b5..d898298f0d1c85 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -108,6 +108,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_max_cluster_size", &mark_for_compilation_flags->tf_xla_max_cluster_size, "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_annotate_cluster_id", + &mark_for_compilation_flags->tf_xla_annotate_cluster_id, + "Allow operator names to influence clustering scheume." + "Operators whose name starting with .cluster.{id} will likely" + "to be clustered together if the ids are the same number. " + ".cluster.none will not be clustered with those having numbered id"), Flag( "tf_xla_ops_to_cluster", &mark_for_compilation_flags->tf_xla_ops_to_cluster, @@ -232,6 +238,7 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_min_cluster_size = 4; mark_for_compilation_flags->tf_xla_max_cluster_size = std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_annotate_cluster_id = false; mark_for_compilation_flags->tf_xla_clustering_debug = false; mark_for_compilation_flags->tf_xla_cpu_global_jit = false; mark_for_compilation_flags->tf_xla_clustering_fuel = diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 0d0c5082cf9a82..e9e14fbabdd7aa 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -62,6 +62,9 @@ struct MarkForCompilationPassFlags { // Maximum number of operators in an XLA compilation. int32 tf_xla_max_cluster_size; + // Enable operator name to influence clustering decision + bool tf_xla_annotate_cluster_id; + // If non-empty, limit XLA clustering to the following TF operations. string tf_xla_ops_to_cluster; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index c3a24f3e0f7163..584be983432789 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -123,6 +124,9 @@ class MarkForCompilationPassImpl { std::atomic* fuel; bool dump_graphs; + + // Enable models to influcence clustering with operator names + int annotate_cluster_id; }; MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, @@ -245,7 +249,11 @@ class MarkForCompilationPassImpl { " others #", cycles_graph_node_id(), ">"); } + int annotated_id() const { return annotated_id_; } + void set_annotated_id(int id) { annotated_id_ = id; } + private: + int annotated_id_ = -1; int cluster_size_ = 1; int cycles_graph_node_id_; int effective_cluster_size_; @@ -317,6 +325,8 @@ class MarkForCompilationPassImpl { return compilation_candidates_.find(n) != compilation_candidates_.end(); } + absl::Status AssignAnnotatedClusterIDs(); + // Tries to contract the edge from cluster `from` to cluster `to`. Returns // true if successful. absl::StatusOr TryToContractEdge(Cluster* from, Cluster* to); @@ -686,6 +696,12 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { // representative names the node in the 'cycles' graph that represents the // cluster. TF_RETURN_IF_ERROR(BuildInitialClusterSet()); + + // Source model may be annotated with preferred clusters. This function + // just interpreter the annotations and assign preferred IDs + if (debug_options_.annotate_cluster_id) { + TF_RETURN_IF_ERROR(AssignAnnotatedClusterIDs()); + } return true; } @@ -1548,6 +1564,45 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( return false; } +absl::Status MarkForCompilationPassImpl::AssignAnnotatedClusterIDs(void) { + VLOG(1) << "Run AssignAnnotatedClusterIDs"; + + for (Node* node : graph_->nodes()) { + Cluster * cluster = GetClusterForNode(node); + auto name = node->name(); + if (cluster) { + std::string pat = "^\\.cluster\\.(\\d+|none)"; + std::regex idPattern(pat); + std::smatch matched; + if (std::regex_search(name, matched, idPattern)) { + auto m = matched.str(1); + if (m == "none") { + // Prefer not to cluster + VLOG(1) << name << " : Default annotated cluster id -1"; + cluster->set_annotated_id(-1); + } + else { + try { + int id = std::stoi(m); + cluster->set_annotated_id(id); + VLOG(1) << name << " : Set annotated cluster id " << m; + } + catch (...) { + VLOG(1) << name << " : Invalid cluster id: " << m; + } + } + } + else { + VLOG(1) << "Not matched: " << name << " pattern is " << pat; + } + } + else { + VLOG(1) << name << ": Not initially clustered"; + } + } + return absl::OkStatus(); +} + bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( Cluster* from, Cluster* to, absl::string_view reason) { VLOG(3) << EdgeContractionFailureMsg(from, to, reason); @@ -1569,6 +1624,11 @@ absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( return false; } + if (debug_options_.annotate_cluster_id && from->annotated_id() != to->annotated_id()) { + return LogNotContractableAndReturnFalse( + from, to, "the two nodes do not have same annotated ids"); + } + TF_ASSIGN_OR_RETURN(bool devices_compatible, AreDevicesCompatible(*from, *to)); if (!devices_compatible) { @@ -1962,6 +2022,7 @@ absl::Status MarkForCompilationPass::Run( debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); debug_options.dump_graphs = flags->tf_xla_clustering_debug; + debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; return MarkForCompilation(options, debug_options); } @@ -1981,6 +2042,7 @@ absl::Status MarkForCompilationPass::RunForTest( debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); debug_options.dump_graphs = flags->tf_xla_clustering_debug; + debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; return MarkForCompilation(options, debug_options); } From 03d5b4b5473c81fee16c98204ec11a8c393e5336 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:31:14 +0000 Subject: [PATCH 03/16] Add dynamic batch runtime plumbing for CPU execution Adds the CPU runtime plumbing needed to carry the dynamic batch value through execution and lowers the supporting runtime pieces. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../jit/encapsulate_subgraphs_pass.cc | 36 ++- tensorflow/compiler/jit/flags.cc | 8 + tensorflow/compiler/jit/flags.h | 6 + tensorflow/compiler/jit/kernels/xla_ops.cc | 82 ++++++- .../compiler/jit/mark_for_compilation_pass.cc | 224 ++++++++++++++++++ .../compiler/tf2xla/kernels/reduction_ops.cc | 16 +- tensorflow/compiler/tf2xla/xla_argument.h | 3 + tensorflow/compiler/tf2xla/xla_compiler.cc | 18 ++ tensorflow/core/graph/subgraph.cc | 9 + tensorflow/core/kernels/BUILD | 10 + tensorflow/core/kernels/batch_size_resource.h | 16 ++ tensorflow/core/kernels/function_ops.cc | 35 ++- tensorflow/core/kernels/function_ops.h | 2 + .../cpu/codegen/kernel_api_ir_builder.cc | 35 ++- .../cpu/codegen/kernel_api_ir_builder.h | 4 + .../xla/xla/backends/cpu/runtime/kernel.cc | 39 +-- .../xla/xla/backends/cpu/runtime/kernel.h | 17 +- .../xla/backends/cpu/runtime/kernel_c_api.h | 1 + .../xla/backends/cpu/runtime/kernel_thunk.cc | 8 +- .../xla/xla/backends/cpu/runtime/thunk.h | 1 + third_party/xla/xla/executable_run_options.h | 8 + .../xla/xla/hlo/builder/xla_builder.cc | 8 + third_party/xla/xla/hlo/builder/xla_builder.h | 2 + third_party/xla/xla/service/cpu/BUILD | 10 + .../xla/xla/service/cpu/cpu_compiler.cc | 2 +- .../xla/xla/service/cpu/cpu_executable.cc | 52 +++- .../xla/xla/service/cpu/dot_op_emitter.cc | 13 +- .../cpu/executable_run_options_offset.cc | 28 +++ .../cpu/executable_run_options_offset.h | 8 + third_party/xla/xla/service/cpu/ir_emitter.cc | 27 +++ third_party/xla/xla/service/cpu/ir_emitter.h | 1 + .../xla/xla/service/cpu/ir_emitter2.cc | 27 +++ third_party/xla/xla/service/cpu/ir_emitter2.h | 3 + .../xla/service/cpu/parallel_loop_emitter.cc | 8 +- .../xla/xla/service/cpu/thunk_emitter.cc | 17 ++ .../xla/xla/service/cpu/thunk_emitter.h | 3 + .../xla/xla/service/elemental_ir_emitter.cc | 81 +++++-- third_party/xla/xla/service/llvm_ir/BUILD | 1 + .../xla/xla/service/llvm_ir/ir_array.cc | 39 ++- .../xla/xla/service/llvm_ir/llvm_loop.cc | 46 +++- .../xla/xla/service/llvm_ir/llvm_loop.h | 8 +- .../xla/xla/service/llvm_ir/llvm_util.cc | 62 +++++ .../xla/xla/service/llvm_ir/llvm_util.h | 2 + .../xla/xla/service/llvm_ir/loop_emitter.cc | 29 +++ third_party/xla/xla/shape.h | 3 + 45 files changed, 962 insertions(+), 96 deletions(-) create mode 100644 tensorflow/core/kernels/batch_size_resource.h create mode 100644 third_party/xla/xla/service/cpu/executable_run_options_offset.cc create mode 100644 third_party/xla/xla/service/cpu/executable_run_options_offset.h diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 3e8a43ce08ed58..1bcae2384a8058 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -57,6 +57,12 @@ limitations under the License. namespace tensorflow { +static const absl::flat_hash_set kFailingOps = { + "Pad", + "Where", + // add more here +}; + const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; @@ -470,6 +476,24 @@ absl::Status Encapsulator::Subgraph::RecordArg( DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); + AttrSlice attrs = src_node->attrs(); + auto shape_attr = attrs.FindByString("_output_shapes"); + if (shape_attr && shape_attr->has_list()) { + const TensorShapeProto& shape = shape_attr->list().shape(src_slot); + if (shape.dim_size() >= 1 && shape.dim(0).size() == -1) { + VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":" + << src_slot; + builder.Attr("_is_batch", true); + } + } else { + // if cluster argument is the real argument. + auto build_attr = attrs.FindByString("_is_batch"); + if (build_attr) { + VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":" + << src_slot; + builder.Attr("_is_batch", true); + } + } absl::Status s = builder.Finalize(&arg_def); if (!s.ok()) return s; @@ -1143,6 +1167,14 @@ static absl::Status RenumberArguments(Graph* graph, return absl::OkStatus(); } +static bool SubgraphHasFailingOps(const Graph& g) { + for (Node* n : g.op_nodes()) { + if (n->IsRetval()) continue; + if (kFailingOps.contains(n->def().op())) return true; + } + return false; +} + absl::Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; @@ -1289,8 +1321,8 @@ absl::Status EncapsulateSubgraphsPass::Run( // TODO(phawkins): add a forward is-constant analysis, similarly split // outputs into host-memory constants and device-memory non-constants. - - AddNodeAttr(kXlaCompiledKernelAttr, true, node); + bool compile_enabled = !SubgraphHasFailingOps(**subgraph); + AddNodeAttr(kXlaCompiledKernelAttr, compile_enabled, node); AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); return absl::OkStatus(); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index d898298f0d1c85..b3b383d7530ee5 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -114,6 +114,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { "Operators whose name starting with .cluster.{id} will likely" "to be clustered together if the ids are the same number. " ".cluster.none will not be clustered with those having numbered id"), + Flag("tf_xla_cluster_parallel", + &mark_for_compilation_flags->tf_xla_cluster_parallel, + "Split parallel compute subgraph info different clusters"), Flag( "tf_xla_ops_to_cluster", &mark_for_compilation_flags->tf_xla_ops_to_cluster, @@ -161,6 +164,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { &mark_for_compilation_flags->tf_xla_deterministic_cluster_names, "Causes the function names assigned by auto clustering to be " "deterministic from run to run."), + Flag("tf_xla_enable_dynamic_sizes", + &mark_for_compilation_flags->tf_xla_enable_dynamic_sizes, + "Enable dynamic sizes support."), Flag("tf_xla_persistent_cache_directory", &mark_for_compilation_flags->tf_xla_persistent_cache_directory, "If non-empty, JIT-compiled executables are saved to and loaded " @@ -239,6 +245,7 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_max_cluster_size = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_annotate_cluster_id = false; + mark_for_compilation_flags->tf_xla_cluster_parallel = false; mark_for_compilation_flags->tf_xla_clustering_debug = false; mark_for_compilation_flags->tf_xla_cpu_global_jit = false; mark_for_compilation_flags->tf_xla_clustering_fuel = @@ -248,6 +255,7 @@ void AllocateAndParseFlags() { mark_for_compilation_flags ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false; mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false; + mark_for_compilation_flags->tf_xla_enable_dynamic_sizes = false; mark_for_compilation_flags->tf_xla_persistent_cache_directory = ""; mark_for_compilation_flags->tf_xla_persistent_cache_device_types = ""; mark_for_compilation_flags->tf_xla_persistent_cache_read_only = false; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index e9e14fbabdd7aa..ebb6b7a518146e 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -65,6 +65,9 @@ struct MarkForCompilationPassFlags { // Enable operator name to influence clustering decision bool tf_xla_annotate_cluster_id; + // Split parallel compute subgraph info different clusters + bool tf_xla_cluster_parallel; + // If non-empty, limit XLA clustering to the following TF operations. string tf_xla_ops_to_cluster; @@ -96,6 +99,9 @@ struct MarkForCompilationPassFlags { // so that they remain stable from run to run of auto clusteing. bool tf_xla_deterministic_cluster_names; + // If true enables support of dynamic sizes. + bool tf_xla_enable_dynamic_sizes; + // If non-empty, JIT-compiled executables are saved to and loaded from the // specified file system directory path. std::string tf_xla_persistent_cache_directory; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 468b85280e2a47..2a8b1110639b49 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -71,6 +71,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/batch_size_resource.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" @@ -427,9 +428,52 @@ absl::Status CompileToLocalExecutable( XlaCompiler::CompileOptions compile_options = GenerateCompileOptions(has_ref_vars, may_alias_resource_update); - return xla_device_compiler->CompileIfNeeded( - options, function, args, compile_options, compile_mode, profiler, - compilation_result, executable); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + // Rewriting the argument with the magic number if they have dynamic + // dimension, detecting dynamic dimension via _is_batch attr in the + // argument. + std::vector norm_args(args.begin(), args.end()); + constexpr int64_t kMagicBound = 977; + if (options.flib_def != nullptr) { + const FunctionDef* fdef = options.flib_def->Find(function.name()); + if (fdef != nullptr) { + for (const auto& kv : fdef->arg_attr()) { + int arg_index = kv.first; + const auto& attr_map = kv.second.attr(); + auto it = attr_map.find("_is_batch"); + + const AttrValue& v = it->second; + if (it == attr_map.end()) continue; + norm_args[arg_index].dynamic_dim = 0; + } + } + } + for (int i = 0; i < norm_args.size(); ++i) { + auto& arg = norm_args[i]; + // argument rewrite. + if (arg.dynamic_dim == 0) { + TensorShape& shp = std::get(arg.shape); + int64_t old = shp.dim_size(0); + shp.set_dim(0, kMagicBound); + } + // constant argument rewrite otherwise it still store the incoming batch + // request. + if (arg.kind == XlaCompiler::Argument::kConstant) { + auto flat = arg.constant_value.flat(); + int32 old_batch = flat(0); + flat(0) = static_cast(kMagicBound); + } + } + + return xla_device_compiler->CompileIfNeeded( + options, function, norm_args, compile_options, compile_mode, profiler, + compilation_result, executable); + } else { + return xla_device_compiler->CompileIfNeeded( + options, function, args, compile_options, compile_mode, profiler, + compilation_result, executable); + } } absl::Status GetUpdatedVariables( @@ -802,14 +846,24 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx, function_, has_ref_vars_, platform_info_, args, compile_mode, /*may_alias_resource_update=*/false, &client, &kernel, &executable); } + if (compile_mode != DeviceCompileMode::kLazy || status.code() != error::UNIMPLEMENTED) { - OP_REQUIRES_OK(ctx, status); + if ((status != OkStatus()) && + (status.code() != error::UNIMPLEMENTED) && + (compile_mode == DeviceCompileMode::kLazy)) { + // We set the error to error::UNIMPLEMENTED so it falls in the + // conditions of the if to fall back to TensorFlow function call + status = tensorflow::errors::Unimplemented(status.ToString()); + } else { + OP_REQUIRES_OK(ctx, status); + } } if (status.code() == error::UNIMPLEMENTED) { - LOG(WARNING) << "Compilation failed:" << status - << ". Falling back to TF function call."; + LOG(WARNING) << "[HUAWEI] Compilation of the cluster failed with:"; + LOG(WARNING) << "[HUAWEI] " << status; + LOG(WARNING) << "[HUAWEI] Falling back to TF function call.\n"; BroadcastOptimizationRemark( XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString()) @@ -817,6 +871,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { executable = nullptr; pjrt_executable = nullptr; mutex_lock guard(cannot_compile_cluster_mu_); + // TODO: decide if we want to set this flag to true, as we may want to + // allow the cluster to try to compile again later in time. cannot_compile_cluster_ = true; } } @@ -953,6 +1009,20 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { xla::ExecutableRunOptions run_options; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + + OP_REQUIRES_OK(ctx, step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr)); + + run_options.set_batch_size(bsr->GetBatchSize()); + VLOG(1) << "run_options.batch_size is set to: " + << run_options.batch_size() << ". step_id: " << ctx->step_id(); + bsr->Unref(); + } + // Host callbacks used for HLO send/recv. xla::SendDeviceMemoryFunction send_function = GetSendDeviceMemoryFunction(ctx, key); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 584be983432789..7ed57392c0ce9f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,6 +30,7 @@ limitations under the License. #include #include #include +#include #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" @@ -110,6 +111,8 @@ class MarkForCompilationPassImpl { // stable from run to rum. bool deterministic_cluster_names; + bool enable_dynamic_sizes; + int max_cluster_size; int min_cluster_size; @@ -127,6 +130,8 @@ class MarkForCompilationPassImpl { // Enable models to influcence clustering with operator names int annotate_cluster_id; + + bool enable_cluster_parallel; }; MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, @@ -251,9 +256,12 @@ class MarkForCompilationPassImpl { int annotated_id() const { return annotated_id_; } void set_annotated_id(int id) { annotated_id_ = id; } + int chain_id() const {return chain_id_;} + void set_chain_id(int id) {chain_id_ = id;} private: int annotated_id_ = -1; + int chain_id_ = -1; int cluster_size_ = 1; int cycles_graph_node_id_; int effective_cluster_size_; @@ -326,6 +334,14 @@ class MarkForCompilationPassImpl { } absl::Status AssignAnnotatedClusterIDs(); + void collectInputNodes(std::set &path_nodes); + void collectMergeNodes(const std::vector& nodeSet, + std::set &merger_nodes); + void collectPathNodes(Node* start, std::set &path_nodes, + std::set& merger_nodes); + std::map> collectParallelNode( + const std::vector& nodeSet); + absl::Status AssignParallelChains(); // Tries to contract the edge from cluster `from` to cluster `to`. Returns // true if successful. @@ -702,6 +718,9 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { if (debug_options_.annotate_cluster_id) { TF_RETURN_IF_ERROR(AssignAnnotatedClusterIDs()); } + if (debug_options_.enable_cluster_parallel) { + TF_RETURN_IF_ERROR(AssignParallelChains()); + } return true; } @@ -1044,6 +1063,7 @@ absl::Status MarkForCompilationPassImpl::CreateClusters() { // trouble. if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size || + cluster->chain_id() != -1 || cluster->has_functional_control_flow() || cluster->is_xla_compile_attr_true()) { string& name = cluster_names[cluster->cycles_graph_node_id()]; @@ -1609,8 +1629,208 @@ bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( return false; } +void MarkForCompilationPassImpl::collectInputNodes(std::set &path_nodes) { + std::unordered_map out_degree_count; + + // 4. Initialize the queue and add nodes from path_nodes + std::queue queue; + for (auto node : path_nodes) { + queue.push(node); + } + + // 5. BFS search + while (!queue.empty()) { + auto u = queue.front(); + queue.pop(); + + // Traverse all predecessor nodes + for (const Edge* e : u->in_edges()) { + Node* p = e->src(); + if (path_nodes.find(p) != path_nodes.end() || + !IsCompilationCandidate(p)) { + continue; + } + if (out_degree_count.count(p) == 0) { + // Initialize the out-degree count for the node + out_degree_count[p] = p->out_edges().size(); + } + // Decrease the out-degree count for the predecessor node + out_degree_count[p] -= 1; + + // If the predecessor node's out-degree count is 0 and not in path_nodes, + // add it to path_nodes and queue + if (out_degree_count[p] == 0) { + path_nodes.insert(p); + queue.push(p); + VLOG(3) << p->DebugString(); + } + } + } +} + +void MarkForCompilationPassImpl::collectMergeNodes( + const std::vector& nodeSet, std::set &merger_nodes) { + // 1. Collect the number of nodeSet that can reach each node + std::map> reach_map; + for (Node* start : nodeSet) { + std::set visited; + std::vector stack = {start}; + while (!stack.empty()) { + Node* cur = stack.back(); + stack.pop_back(); + if (visited.count(cur)) continue; + visited.insert(cur); + for (const Edge* e : cur->out_edges()) { + Node* next = e->dst(); + if (!visited.count(next)) + stack.push_back(next); + } + } + reach_map[start] = std::move(visited); + } + + // 2. Determine the merger node + std::map node_reach_count; + std::set all_nodes; + for (const auto& kv : reach_map) { + for (Node* n : kv.second) { + node_reach_count[n]++; + all_nodes.insert(n); + } + } + for (Node* n : all_nodes) { + // Condition 1: Reached by multiple sources + if (node_reach_count[n] >= 2) merger_nodes.insert(n); + // Condition 2: No output edges + if (n->out_edges().empty()) merger_nodes.insert(n); + } +} + +void MarkForCompilationPassImpl::collectPathNodes( + Node* start, std::set &path_nodes, std::set& merger_nodes) { + std::vector stack = {start}; + std::set visited; + + while (!stack.empty()) { + Node* cur = stack.back(); + stack.pop_back(); + if (visited.count(cur) || !IsCompilationCandidate(cur)) { + continue; + } + visited.insert(cur); + + // Stop search met the merger node + if (merger_nodes.count(cur)) { + if (cur->out_edges().empty()) path_nodes.insert(cur); + continue; + } + path_nodes.insert(cur); + + for (const Edge* e : cur->out_edges()) { + Node* next = e->dst(); + if (!visited.count(next)) { + stack.push_back(next); + } + } + } +} + +// collectParallelNode +// Search the serial merger nodes based on the parallel matmul starting points +// Search along the output edge to get the boundary from start to all merger points +// Search along the input edge to get the entire parallel computation graph +std::map> +MarkForCompilationPassImpl::collectParallelNode( + const std::vector& nodeSet) { + std::set merger_nodes; + collectMergeNodes(nodeSet, merger_nodes); + + // Collect path nodes + std::map> result; + for (Node* start : nodeSet) { + VLOG(4) << "Search parallel graph form node: " << start->DebugString(); + std::set path_nodes; + + // Search along output edge form start to merger nodes + collectPathNodes(start, path_nodes, merger_nodes); + + VLOG(4) << "Collect path nodes:"; + for (auto node : path_nodes) { + VLOG(4) << node->type_string(); + } + + VLOG(4) << "Collect input nodes:"; + // search along input edge + collectInputNodes(path_nodes); + + result[start] = std::vector(path_nodes.begin(), path_nodes.end()); + } + return result; +} + +// Collect parallel matmuls as input nodes into nodeSet +// Use collectParallelNode to get the parallel subgraph +// Mark parallel nodes change ID +absl::Status MarkForCompilationPassImpl::AssignParallelChains() { + VLOG(4) << "Run AssignParallelChains"; + // Record the matmuls that can be paralleled + std::vector> parallel_matmuls; + int minParallelMatmulNum = 2; + int next_chain_id = 0; + // Collect matmul nodes with shared input to parallel_matmuls + for (Node* node : graph_->nodes()) { + if (node->out_edges().size() < 2) continue; + std::vector matmul_nodes; + for (const Edge* e : node->out_edges()) { + if (e->IsControlEdge()) continue; + Node* succ = e->dst(); + VLOG(4) << "Find matmul node: " << succ->type_string() << " : " << succ->DebugString(); + if (succ->type_string() == "MatMul") + matmul_nodes.push_back(succ); + } + if (matmul_nodes.size() >= minParallelMatmulNum) + parallel_matmuls.push_back(matmul_nodes); + } + + for (auto matmul_nodes : parallel_matmuls) { + VLOG(4) << "Process matmul nodes: Total " << matmul_nodes.size() + << " sub matmuls"; + bool visited = false; + for (auto matmul : matmul_nodes) { + VLOG(4) << matmul->name(); + Cluster* cluster = GetClusterForNode(matmul); + if (!cluster || cluster->chain_id() != -1) { + visited = true; + break; + } + } + if (visited) { + VLOG(4) << "stop collect: has visited matmul"; + continue; + } + + std::map> subgraphMap = + collectParallelNode(matmul_nodes); + for (auto it : subgraphMap) { + auto nodeSet = it.second; + VLOG(4) << "One of Parallel sub-graph is: "; + for (auto node : nodeSet) { + VLOG(4) << node->DebugString(); + Cluster* cluster = GetClusterForNode(node); + cluster->set_chain_id(next_chain_id); + } + next_chain_id++; + } + } + return absl::OkStatus(); +} + absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( Cluster* from, Cluster* to) { + if (from->chain_id() != to->chain_id()) { + return LogNotContractableAndReturnFalse( + from, to, "nodes are in different parallel chains"); + } DCHECK(from->deadness_predicate().has_value() == to->deadness_predicate().has_value()); if (from->deadness_predicate() != to->deadness_predicate()) { @@ -2018,11 +2238,14 @@ absl::Status MarkForCompilationPass::Run( debug_options.ignore_xla_compile_attr = false; debug_options.deterministic_cluster_names = flags->tf_xla_deterministic_cluster_names; + debug_options.enable_dynamic_sizes = + flags->tf_xla_enable_dynamic_sizes; debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); debug_options.dump_graphs = flags->tf_xla_clustering_debug; debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; + debug_options.enable_cluster_parallel = flags->tf_xla_cluster_parallel; return MarkForCompilation(options, debug_options); } @@ -2038,6 +2261,7 @@ absl::Status MarkForCompilationPass::RunForTest( flags->tf_xla_disable_resource_variable_safety_checks_for_debugging; debug_options.ignore_xla_compile_attr = true; debug_options.deterministic_cluster_names = deterministic_cluster_names; + debug_options.enable_dynamic_sizes = false; debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 5f911018c244b5..be6b6e077656b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" @@ -147,10 +148,21 @@ class MeanOp : public XlaReductionOp { xla::XlaOp result = reduce_output; xla::Shape bounded_shape = builder->GetShape(input).value(); int64_t divisor_value = bounded_shape.dimensions(dimensions_to_reduce[0]); - auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + xla::XlaOp divisor; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes && dimensions_to_reduce[0] == 0) { + divisor = xla::GetOuterBatchValue(input); + } else { + divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + } for (int i = 1; i < dimensions_to_reduce.size(); i++) { int64_t size_value = bounded_shape.dimensions(dimensions_to_reduce[i]); - auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + xla::XlaOp size; + if (flags->tf_xla_enable_dynamic_sizes && dimensions_to_reduce[i] == 0) { + size = xla::GetOuterBatchValue(input); + } else { + size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + } if (size_value * divisor_value > std::numeric_limits::max()) { result = result / xla::ConvertElementType(divisor, xla_reduction_type_); divisor_value = size_value; diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index 9e2eccd29b1885..86811d3416eb46 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -53,6 +53,9 @@ struct XlaArgument { kTensorList, }; + //To keep dynamic dim as an attribute of the argument. + int64_t dynamic_dim = -1; + Kind kind = kInvalid; // The type of the argument. If the argument is a resource, this diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index b7cff00c8a0bfe..65df8e1fc15c73 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -840,6 +840,8 @@ absl::Status XlaCompiler::CompileFunction( } } else { TensorShape tensor_shape = std::get(args[i].shape); + AttrSlice n_attrs = fbody->arg_nodes[i]->attrs(); + std::vector output_shapes; fbody->arg_nodes[i]->ClearAttr("_output_shapes"); fbody->arg_nodes[i]->AddAttr("_output_shapes", std::vector{tensor_shape}); @@ -1194,6 +1196,13 @@ absl::Status XlaCompiler::BuildArguments( xla::OpMetadata arg_metadata; arg_metadata.set_op_name(arg.node_name); + + if (arg.dynamic_dim==0) { + // Encode dynamic dims as a string in op_type, so it appears in HLO metadata. + arg_metadata.set_op_type( + absl::StrCat("XLA_Arg_dyn[",arg.dynamic_dim, "]")); + } + builder->SetOneShotOpMetadata(arg_metadata); arg_handles[i] = xla::GetTupleElement(tuple, i); } @@ -1203,6 +1212,15 @@ absl::Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, it == arg_shardings.end() ? std::optional() : it->second); + auto& arg = args[input_to_args->at(i)]; + xla::OpMetadata arg_metadata; + arg_metadata.set_op_name(arg.node_name); + if (arg.dynamic_dim==0) { + arg_metadata.set_op_type( + absl::StrCat("XLA_Arg_dyn[",arg.dynamic_dim, "]")); + } + builder->SetOneShotOpMetadata(arg_metadata); + if (is_entry_computation) { // Add an entry to is_same_across_replicas for every leaf buffer. std::vector is_same_across_replicas( diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index bb47f37ef7fbe3..0d66e1f868c2a5 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -79,6 +79,15 @@ absl::Status FeedInputs( TF_RETURN_IF_ERROR( feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node)); + // Set an attribute in _Arg node to indicate it has a batch dimension + auto node_attrs = n->attrs(); + const AttrValue* shape_attr = node_attrs.FindByString("_output_shapes"); + if (shape_attr && shape_attr->has_list()) { + const TensorShapeProto& shape = shape_attr->list().shape(0); + if (shape.dim_size() >= 1 && shape.dim(0).size() == -1) { + feed_node->AddAttr("_is_batch", true); + } + } // Update name_index (*name_index)[feed_node->name()] = feed_node; // Duplicate control edges aren't allowed, but feed_node was *just* created diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index d25f6736317f61..8b71e9a4c0ad7c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3072,6 +3072,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/kernels:batch_size_resource", ], ) @@ -7972,6 +7973,15 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "batch_size_resource", + srcs = ["batch_size_resource.h"], + deps = [ + "//tensorflow/core:framework", + ], +) + + # For a more maintainable build this target should not exist and the headers # should be split into the existing cc_library targets, but this change was # automatically done so that we can remove long standing issues and complexity diff --git a/tensorflow/core/kernels/batch_size_resource.h b/tensorflow/core/kernels/batch_size_resource.h new file mode 100644 index 00000000000000..e457b65e3293e9 --- /dev/null +++ b/tensorflow/core/kernels/batch_size_resource.h @@ -0,0 +1,16 @@ +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { +const string BatchSizeResourceName = "BatchSizeResource_"; +class BatchSizeResource : public ResourceBase { + public: + ~BatchSizeResource() override { + VLOG(1) << "BatchSizeResource destroyed with batch size: " << batch_size_; + } + string DebugString() const override { return BatchSizeResourceName; } + void SetBatchSize(size_t s) { batch_size_ = s; } + size_t GetBatchSize() { return batch_size_; } + private: + size_t batch_size_ = 0; +}; +} diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 864855de1d69f6..6ea77f78f04797 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/kernels/batch_size_resource.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/device_name_utils.h" @@ -59,16 +60,42 @@ void ArgOp::Compute(OpKernelContext* ctx) { } }; + Tensor t; if (frame->CanConsumeArg(index_)) { - Tensor val; - frame->ConsumeArg(index_, &val); - OP_REQUIRES_OK(ctx, validate_type(val)); - ctx->set_output(0, std::move(val)); + frame->ConsumeArg(index_, &t); + OP_REQUIRES_OK(ctx, validate_type(t)); + ctx->set_output(0, std::move(t)); + val = &t; } else { OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); OP_REQUIRES_OK(ctx, validate_type(*val)); ctx->set_output(0, *val); } + if (is_batch_) { + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + + OP_REQUIRES_OK(ctx, step_container->LookupOrCreate( + ctx->resource_manager(), BatchSizeResourceName, &bsr, + [](BatchSizeResource** ret) -> Status { + *ret = new BatchSizeResource(); + return OkStatus(); + })); + + const int64_t batch_size = val->dim_size(0); + if (bsr->GetBatchSize() == 0) { + bsr->SetBatchSize(batch_size); + VLOG(1) << "Set batch_size from 0 to " << batch_size + << ". step_id: " << ctx->step_id(); + } else if (bsr->GetBatchSize() != batch_size) { + VLOG(1) << "Warning: Set batch_size from " << bsr->GetBatchSize() + << ". step_id: " << ctx->step_id(); + bsr->SetBatchSize(batch_size); + } else { + VLOG(1) << "batch_size already set to " << batch_size; + } + bsr->Unref(); + } } RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { diff --git a/tensorflow/core/kernels/function_ops.h b/tensorflow/core/kernels/function_ops.h index 552e1e6c515e3b..c9e8fdcd894e25 100644 --- a/tensorflow/core/kernels/function_ops.h +++ b/tensorflow/core/kernels/function_ops.h @@ -38,6 +38,7 @@ class ArgOp : public OpKernel { private: int index_; DataType dtype_; + bool is_batch_; ArgOp(const ArgOp&) = delete; void operator=(const ArgOp&) = delete; @@ -54,6 +55,7 @@ class RetvalOp : public OpKernel { private: int index_; DataType dtype_; + bool is_batch_; RetvalOp(const RetvalOp&) = delete; void operator=(const RetvalOp&) = delete; diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc index 2ce99c4b93a878..30f64818209cd4 100644 --- a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc @@ -147,7 +147,7 @@ llvm::StructType* KernelCallFrameTy(llvm::LLVMContext& ctx) { llvm::PointerType* ptr = llvm::PointerType::getUnqual(ctx); llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); return llvm::StructType::create("XLA_CPU_KernelCallFrame", ptr, ptr, i64, - ptr); + ptr, i64); } llvm::FunctionType* KernelFunctionTy(llvm::LLVMContext& ctx) { @@ -316,6 +316,34 @@ auto KernelApiIrBuilder::EmitKernelPrototype( return EmitKernelPrototype(module, name, arguments, results); } +llvm::Value* KernelApiIrBuilder::EmitGetBatchDim(llvm::IRBuilderBase& builder, + llvm::Value* call_frame) { + llvm::LLVMContext& ctx = builder.getContext(); + llvm::Type* ptr = llvm::PointerType::get(ctx, 0); + llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); + llvm::Value* bdim_gep = + builder.CreateStructGEP(call_frame_ty_, call_frame, 4, "bdim_gep"); + llvm::Value* bdim_value = builder.CreateLoad(i64, bdim_gep, "bdim_value"); + +#if defined(PRINT_BATCHSIZE) + // Print batch size + llvm::Function* function = builder.GetInsertBlock()->getParent(); + llvm::Module* module = function->getParent(); + llvm::FunctionType* printfType = llvm::FunctionType::get( + builder.getInt32Ty(), llvm::PointerType::get(builder.getInt8Ty(), 0), + true); + llvm::Value* funcNameStr = + builder.CreateGlobalStringPtr(function->getName()); + llvm::FunctionCallee printfFunc = + module->getOrInsertFunction("printf", printfType); + llvm::Value* formatStr = + builder.CreateGlobalStringPtr("Function: %s, Batch size is : %d!\n"); + builder.CreateCall(printfFunc, {formatStr, funcNameStr, bdim_value}); +#endif + + return bdim_value; +} + auto KernelApiIrBuilder::EmitKernelPrototype( llvm::Module& module, absl::string_view name, absl::Span arguments, @@ -394,6 +422,8 @@ auto KernelApiIrBuilder::EmitKernelPrototype( ir_results.push_back(std::move(ir_result)); } + EmitGetBatchDim(b, call_frame); + // Return null pointer to signal success as we do not support error handling // in the compiled host kernel. llvm::BasicBlock* return_block = @@ -503,7 +533,8 @@ llvm_ir::IrArray KernelApiIrBuilder::EmitKernelArgument( const llvm::DataLayout& data_layout = llvm_module->getDataLayout(); int64_t pointer_size = data_layout.getTypeStoreSize(builder.getPtrTy()); int64_t byte_size = ShapeUtil::ByteSizeOf(shape, pointer_size); - llvm_ir::SetDereferenceableMetadataForLoad(data, byte_size); + if (shape.outer_multiplier() < 0) + llvm_ir::SetDereferenceableMetadataForLoad(data,byte_size); // All buffers pointers passed to host kernels are expected to be invariant // over the whole program. Note the metadata is attached only to loading diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h index 58ffc5bc8594e3..08af675fcb6c37 100644 --- a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h @@ -147,9 +147,13 @@ class KernelApiIrBuilder { llvm_ir::IrArray EmitKernelArgument(llvm::IRBuilderBase& builder, llvm::Value* call_frame, int64_t index, const Shape& shape); + llvm::Function* EmitKernelFunction(llvm::Module& module, absl::string_view name); + llvm::Value* EmitGetBatchDim(llvm::IRBuilderBase& builder, + llvm::Value* call_frame); + private: llvm::LLVMContext& context_; diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel.cc b/third_party/xla/xla/backends/cpu/runtime/kernel.cc index ac1a5d7181ec52..668e10f0a0dd09 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel.cc @@ -64,7 +64,7 @@ template class Kernel::ParallelTask { public: ParallelTask(XLA_CPU_Kernel* kernel, Kernel::ThreadDim thread_dims, - absl::Span args); + size_t batch_size, absl::Span args); // Invokes a host kernel for a given task index. absl::Status operator()(size_t task_index) const; @@ -76,6 +76,7 @@ class Kernel::ParallelTask { XLA_CPU_Kernel* kernel_; XLA_CPU_KernelThreadDim thread_dims_; + size_t batch_size_; absl::InlinedVector args_; size_t num_tasks_; @@ -88,11 +89,13 @@ class Kernel::ParallelTask { template Kernel::ParallelTask::ParallelTask( XLA_CPU_Kernel* kernel, Kernel::ThreadDim thread_dims, + size_t batch_size, absl::Span args) : kernel_(kernel), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), args_(args.begin(), args.end()), num_tasks_(thread_dims_.x * thread_dims_.y * thread_dims_.z), + batch_size_(batch_size), stride_z_(thread_dims_.y * thread_dims_.x), stride_y_(thread_dims_.x) {} @@ -103,7 +106,7 @@ absl::Status Kernel::ParallelTask::operator()( XLA_CPU_KernelThread kernel_thread = Delinearize(task_index); XLA_CPU_KernelCallFrame call_frame = {&thread_dims_, &kernel_thread, - args_.size(), args_.data()}; + args_.size(), args_.data(), batch_size_}; XLA_CPU_KernelError* error = (*kernel_)(&call_frame); @@ -138,12 +141,12 @@ Kernel::Kernel(unsigned arity, XLA_CPU_Kernel* kernel) kernel_(function_->kernel()), arity_(arity) {} -absl::Status Kernel::Launch(const ThreadDim& thread_dims, +absl::Status Kernel::Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span buffers) const { - return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); + return Launch(thread_dims, batch_size, ConvertBuffersToKernelArgs(buffers)); } -absl::Status Kernel::Launch(const ThreadDim& thread_dims, +absl::Status Kernel::Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span args) const { XLA_CPU_KernelThreadDim kernel_thread_dims = { thread_dims.x, @@ -156,8 +159,9 @@ absl::Status Kernel::Launch(const ThreadDim& thread_dims, for (uint64_t x = 0; x < thread_dims.x; ++x) { XLA_CPU_KernelThread kernel_thread = {x, y, z}; - XLA_CPU_KernelCallFrame call_frame = { - &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; + XLA_CPU_KernelCallFrame call_frame = {&kernel_thread_dims, + &kernel_thread, args.size(), + args.data(), batch_size}; XLA_CPU_KernelError* error = (*kernel_)(&call_frame); @@ -172,20 +176,23 @@ absl::Status Kernel::Launch(const ThreadDim& thread_dims, } tsl::AsyncValueRef Kernel::Launch( - const ThreadDim& thread_dims, absl::Span buffers, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span buffers, const Eigen::ThreadPoolDevice* device) const { - return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), device); + return Launch(thread_dims, batch_size, ConvertBuffersToKernelArgs(buffers), + device); } tsl::AsyncValueRef Kernel::Launch( - const ThreadDim& thread_dims, absl::Span args, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span args, const Eigen::ThreadPoolDevice* device) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, args); + absl::Status launched = Launch(thread_dims, batch_size, args); return ABSL_PREDICT_TRUE(launched.ok()) ? OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -197,11 +204,13 @@ tsl::AsyncValueRef Kernel::Launch( std::numeric_limits::max()); if (ABSL_PREDICT_TRUE(thread_dims.y == 1 && thread_dims.z == 1)) { - return Worker::Parallelize(device, num_workers, num_tasks, - ParallelTask(kernel_, thread_dims, args)); + return Worker::Parallelize( + device, num_workers, num_tasks, + ParallelTask(kernel_, thread_dims, batch_size, args)); } else { - return Worker::Parallelize(device, num_workers, num_tasks, - ParallelTask(kernel_, thread_dims, args)); + return Worker::Parallelize( + device, num_workers, num_tasks, + ParallelTask(kernel_, thread_dims, batch_size, args)); } } diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel.h b/third_party/xla/xla/backends/cpu/runtime/kernel.h index 98010905322f67..0d0465af154464 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel.h @@ -73,13 +73,13 @@ class Kernel { // Calls the kernel once in the caller thread for a thread dim (0,0,0). // This is a fast path for small host kernels that have just one thread. - absl::Status CallOnce(absl::Span args) const; + absl::Status CallOnce(absl::Span args, size_t batch_size) const; // Launches the kernel on the current thread by iterating over all threads in // `thread_dims` and calling the kernel function. - absl::Status Launch(const ThreadDim& thread_dims, + absl::Status Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span buffers) const; - absl::Status Launch(const ThreadDim& thread_dims, + absl::Status Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span args) const; // Launches the kernel by iterating over all threads in `thread_dims` and @@ -89,10 +89,12 @@ class Kernel { // Async value returned in constructed state and the caller can access it to // get the number of tasks that are expected to be completed. tsl::AsyncValueRef Launch( - const ThreadDim& thread_dims, absl::Span buffers, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span buffers, const Eigen::ThreadPoolDevice* device) const; tsl::AsyncValueRef Launch( - const ThreadDim& thread_dims, absl::Span args, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span args, const Eigen::ThreadPoolDevice* device) const; // For host platform, we assume that a core is a thread, and we can run at @@ -123,12 +125,13 @@ class Kernel { }; inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::Status Kernel::CallOnce( - absl::Span args) const { + absl::Span args, + size_t batch_size) const { constexpr XLA_CPU_KernelThreadDim kernel_thread_dims = {1, 1, 1}; constexpr XLA_CPU_KernelThread kernel_thread = {1, 1, 1}; XLA_CPU_KernelCallFrame call_frame = {&kernel_thread_dims, &kernel_thread, - args.size(), args.data()}; + args.size(), args.data(), batch_size}; XLA_CPU_KernelError* error = (*kernel_)(&call_frame); diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h b/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h index cbe0568506385d..cf8016759ed217 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h @@ -72,6 +72,7 @@ typedef struct XLA_CPU_KernelCallFrame { size_t num_args; const XLA_CPU_KernelArg* args; + size_t batch_size; } XLA_CPU_KernelCallFrame; // Error reporting for host kernels. NULL means success. diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc index 90c15a09bd677d..7a1c0ec70c449b 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc @@ -223,7 +223,7 @@ KernelThunk::ExecuteInternal( // Use a fast path if kernel called just once. if (ABSL_PREDICT_TRUE(call_once_)) { - TF_RETURN_IF_ERROR(kernel->CallOnce(kernel_args)); + TF_RETURN_IF_ERROR(kernel->CallOnce(kernel_args, params.batch_size)); return OkExecuteEvent(); } @@ -231,10 +231,12 @@ KernelThunk::ExecuteInternal( // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. if (ABSL_PREDICT_TRUE(params.intra_op_threadpool)) { - return kernel->Launch(thread_dim_, kernel_args, params.intra_op_threadpool); + return kernel->Launch(thread_dim_, params.batch_size, kernel_args, + params.intra_op_threadpool); } - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); + TF_RETURN_IF_ERROR( + kernel->Launch(thread_dim_, params.batch_size, kernel_args)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h index d4a56ae88fa55b..5b7042ccb2803f 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h @@ -259,6 +259,7 @@ class Thunk { TaskRunner* task_runner = nullptr; CollectiveExecuteParams* collective_params = nullptr; CustomCallExecuteParams* custom_call_params = nullptr; + int64_t batch_size = 0; ExecuteSession session = ExecuteSession(ExecuteSession::kMaxWorkers, ExecuteSession::kSplitThreshold); }; diff --git a/third_party/xla/xla/executable_run_options.h b/third_party/xla/xla/executable_run_options.h index b377e91670efb9..211932ffcbba98 100644 --- a/third_party/xla/xla/executable_run_options.h +++ b/third_party/xla/xla/executable_run_options.h @@ -195,6 +195,13 @@ class ExecutableRunOptions { return *this; } + ExecutableRunOptions& set_batch_size(int64_t batch_size) { + batch_size_ = batch_size; + return *this; + } + + int64_t batch_size() const { return batch_size_; } + int32_t launch_id() const { return launch_id_; } ExecutableRunOptions& set_run_id(RunId id); @@ -265,6 +272,7 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; int32_t launch_id_ = 0; + int64_t batch_size_ = 0; stream_executor::Stream* device_to_host_stream_ = nullptr; stream_executor::Stream* host_to_device_stream_ = nullptr; ThenExecuteFunction* then_execute_function_ = nullptr; diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 93d7782de50e03..96000563190d77 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -6085,6 +6085,14 @@ XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) { return operand.builder()->GetDimensionSize(operand, dimension); } +XlaOp GetOuterBatchValue(XlaOp operand) { + XlaBuilder* builder = operand.builder(); + return CustomCall(builder, "GetOuterBatchValue", {operand}, + ShapeUtil::MakeShape(S32, {}), "", false, {}, + nullptr, CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion::API_VERSION_ORIGINAL); +} + XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64_t dimension) { return operand.builder()->SetDimensionSize(operand, val, dimension); diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h index 8f479090d86b19..b5ede47dd41476 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.h +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -3055,6 +3055,8 @@ XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, // array shaped. XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); +XlaOp GetOuterBatchValue(XlaOp operand); + // Sets the size of the given dimension of the operand. The operand must be // array shaped. The result will have the same shape as the operand, but the // given dimension will be dynamic (if not already). diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 24247e60993937..877d55221652af 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -1156,6 +1156,16 @@ cc_library( ], ) +cc_library( + name = "executable_run_options_offset", + srcs = ["executable_run_options_offset.cc"], + hdrs = ["executable_run_options_offset.h"], + copts = tsl_copts(), + deps = [ + "//xla:executable_run_options", + ], +) + cc_library( name = "runtime_conv3d", srcs = ["runtime_conv3d.cc"], diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index fcffdafcf5e5ee..abcfae5fccda0a 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -475,7 +475,7 @@ std::unique_ptr> CreateSimplificationPipeline( } // Needs to happen after algebraic simplifier. - pipeline->AddPass(); + // pipeline->AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 158a4000d27722..e161f760a5da03 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -163,14 +163,34 @@ static absl::StatusOr MemoryForAllocation( se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) { VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { - se::DeviceMemoryBase out = arguments[allocation.parameter_number()] + se::DeviceMemoryBase param_mem = arguments[allocation.parameter_number()] .Buffer(allocation.param_shape_index()) .AsDeviceMemoryBase(); - CHECK_LE(allocation.size(), out.size()) - << "Size mismatch on param " << allocation.parameter_number() - << " at shape index " << allocation.param_shape_index().ToString(); - VLOG(3) << "allocation is a parameter"; - return MaybeOwningDeviceMemory{out}; + + const int64_t compiled_bytes = allocation.size(); + const int64_t runtime_bytes = param_mem.size(); + + if(runtime_bytes == 0){ + return MaybeOwningDeviceMemory{param_mem}; + } + + if (compiled_bytes <= runtime_bytes) { + return MaybeOwningDeviceMemory{param_mem}; + } + + //padded owns that device memory + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory padded, + memory_allocator->Allocate(device_ordinal, compiled_bytes)); + + void* dst = padded-> opaque(); + void* src = param_mem.opaque(); + + std::memcpy(dst, src, runtime_bytes); + //Fill rest of them with zeros. + std::memset(static_cast(dst) + runtime_bytes, 0, compiled_bytes - runtime_bytes); + return MaybeOwningDeviceMemory{std::move(padded)}; + } else if (allocation.is_constant()) { VLOG(3) << "allocation is a constant"; if (allocation.index() < constants.size()) { @@ -275,11 +295,28 @@ absl::Status CpuExecutable::ExecuteComputeFunction( return absl::OkStatus(); } +void PrintScalars(absl::Span buffers) { + for (int i = 0; i < buffers.size(); ++i) { + const se::DeviceMemoryBase& dmem = buffers[i].AsDeviceMemoryBase(); + if (dmem.opaque() && dmem.size() >= sizeof(int64_t)) { + int64_t val = 0; + std::memcpy(&val, dmem.opaque(), sizeof(val)); + std::cerr << "Buffer " << i << " scalar: " << val << std::endl; + } else { + std::cerr << "Buffer " << i << " empty or too small" << std::endl; + } + } +} + absl::Status CpuExecutable::ExecuteThunks( const ExecutableRunOptions* run_options, absl::Span buffers) { uint64_t start_ns = tsl::Env::Default()->NowNanos(); + #if defined(PRINT_BATCHSIZE) + PrintScalars(buffers); + #endif + size_t profile_counters_size = 0; int64_t* profile_counters = nullptr; @@ -318,7 +355,8 @@ absl::Status CpuExecutable::ExecuteThunks( intra_op_thread_pool, &task_runner, &collective_execute_params, - &custom_call_execute_params}; + &custom_call_execute_params, + run_options->batch_size()}; auto executed_event = thunks_->Execute(execute_params); tsl::BlockUntilReady(executed_event); diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index e2e4f914579dbb..c3f432897567de 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -862,12 +862,17 @@ absl::Status DotOpEmitter::EmitCallToRuntime() { std::swap(transpose_lhs, transpose_rhs); } + // We work under the assumption that only M can be dynamic. + int batch_multiplier = target_array_.GetShape().outer_multiplier(); + llvm::Value* m_val = (batch_multiplier > 0) + ? xla::llvm_ir::GetBatchDimByName(b_, batch_multiplier) + : b_->getInt64(mat_mult_dims.m); + b_->CreateCall(matmul_func, {executable_run_options_value_, target_array_.GetBasePointer(), - lhs->GetBasePointer(), rhs->GetBasePointer(), - b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n), - b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs), - b_->getInt32(transpose_rhs)}); + lhs->GetBasePointer(), rhs->GetBasePointer(), m_val, + b_->getInt64(mat_mult_dims.n), b_->getInt64(mat_mult_dims.k), + b_->getInt32(transpose_lhs), b_->getInt32(transpose_rhs)}); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/cpu/executable_run_options_offset.cc b/third_party/xla/xla/service/cpu/executable_run_options_offset.cc new file mode 100644 index 00000000000000..2d84ddfa2ea352 --- /dev/null +++ b/third_party/xla/xla/service/cpu/executable_run_options_offset.cc @@ -0,0 +1,28 @@ +#include "executable_run_options_offset.h" +#include "xla/executable_run_options.h" + +namespace xla::cpu { + +// Friend-injection trick to get a pointer-to-private-member for the *real* +// xla::ExecutableRunOptions::batch_size_. +template +struct Linker { + friend constexpr typename Tag::type get_offset(Tag) { return Ptr; } +}; + +struct BatchSizeTag { + using type = int64_t xla::ExecutableRunOptions::*; + friend constexpr type get_offset(BatchSizeTag); +}; + +// Instantiate template to expose &ExecutableRunOptions::batch_size_. +template struct Linker; + +size_t ExecutableRunOptionsBatchSizeOffset() { + auto ptr = get_offset(BatchSizeTag{}); + // Compute offset in bytes from null pointer. + return reinterpret_cast( + &(reinterpret_cast(0)->*ptr)); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/executable_run_options_offset.h b/third_party/xla/xla/service/cpu/executable_run_options_offset.h new file mode 100644 index 00000000000000..3e702454e76ec5 --- /dev/null +++ b/third_party/xla/xla/service/cpu/executable_run_options_offset.h @@ -0,0 +1,8 @@ +#pragma once +#include + +namespace xla::cpu{ + + size_t ExecutableRunOptionsBatchSizeOffset(); + +} diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index feca6552d243f8..cbcea0cb2d1f14 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2358,6 +2358,29 @@ absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { return EmitSliceToDynamic(hlo, source_arrays, target_array); } +absl::Status IrEmitter::HandleOuterBatchValue(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + + llvm_ir::IrArray out_array = GetIrArrayFor(hlo); + int64_t multiplier = hlo->operand(0)->shape().outer_multiplier(); + + if (multiplier <= 0) { + LOG(ERROR) << "Invalid outer multiplier for GetOuterBatchValue: " + << multiplier; + return absl::InvalidArgumentError( + "Invalid outer multiplier for GetOuterBatchValue"); + } + llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(b(), multiplier); + + llvm_ir::IrArray::Index out_index(/*multidimensional_index=*/{}, hlo->shape(), + b()->getInt32Ty()); + + llvm::Value* out_ptr = out_array.EmitArrayElementAddress(out_index, b()); + b()->CreateStore(bdim_value, out_ptr); + + return absl::OkStatus(); +} + absl::Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); @@ -2806,6 +2829,10 @@ absl::Status IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { #endif // INTEL_MKL absl::Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + if (custom_call->custom_call_target() == "GetOuterBatchValue") { + return HandleOuterBatchValue(custom_call); + } + if (custom_call->custom_call_target() == "PadToStatic") { return HandlePadToStatic(custom_call); } diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 40f54d2f4bff97..0bd7fcc6b377a9 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -332,6 +332,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, private: absl::Status HandleSliceToDynamic(HloInstruction* hlo); + absl::Status HandleOuterBatchValue(HloInstruction* hlo); absl::Status HandlePadToStatic(HloInstruction* hlo); absl::Status HandleTopK(HloInstruction* hlo) override; absl::Status HandleAllReduceSingleReplica(HloInstruction* crs); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index bce2108bb87572..927d43f9d42fc5 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -195,6 +195,33 @@ absl::StatusOr IrEmitter2::EmitPadHostKernel( KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } +absl::StatusOr +IrEmitter2::EmitGetOuterBatchValueHostKernel(const HloInstruction* getBatch) { + VLOG(2) << "Emit GetOuterBatchValue host kernel: " << getBatch->name(); + + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(getBatch)); + llvm_ir::IrArray operand_array = kernel_prototype.arguments[0]; + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + int64_t multiplier = getBatch->operand(0)->shape().outer_multiplier(); + if (multiplier <= 0) { + LOG(ERROR) << "Invalid outer multiplier for GetOuterBatchValue: " + << multiplier; + return absl::InvalidArgumentError( + "Invalid outer multiplier for GetOuterBatchValue"); + } + llvm::IRBuilder<> b(module_->getContext()); + b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); + llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(&b, multiplier); + llvm_ir::IrArray::Index output_index(/*multidimensional_index=*/{}, + getBatch->shape(), b.getInt32Ty()); + llvm::Value* output_ptr = + output_array.EmitArrayElementAddress(output_index, &b); + b.CreateStore(bdim_value, output_ptr); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); +} + absl::StatusOr IrEmitter2::EmitFusionHostKernel( const HloFusionInstruction* fusion) { VLOG(2) << "Emit fusion host kernel: " << fusion->name(); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h index e720e06f37642c..2a0b98d5856332 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.h +++ b/third_party/xla/xla/service/cpu/ir_emitter2.h @@ -113,6 +113,9 @@ class IrEmitter2 { // Emits a host kernel for the pad instruction. absl::StatusOr EmitPadHostKernel(const HloInstruction* pad); + absl::StatusOr EmitGetOuterBatchValueHostKernel( + const HloInstruction* getBatch); + // Emits a host kernel for the given fusion instruction. absl::StatusOr EmitFusionHostKernel( const HloFusionInstruction* fusion); diff --git a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc index 2bfffd88df937e..adabc35f90c23f 100644 --- a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc @@ -61,6 +61,8 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // performance with a large improvement in compile time. auto unroll_mode = (i == 0) ? llvm_ir::UnrollMode::kDefaultUnroll : llvm_ir::UnrollMode::kNoUnroll; + bool is_batch_dim = (dimension == 0) && shape_.outer_multiplier() > 0; + if (bounds_index < dynamic_loop_bounds_->size()) { // Emit dynamic loop bounds for this dimension. Dynamic loop bounds // are read from ir function dynamic loop bounds argument. @@ -69,14 +71,16 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::unique_ptr loop = loop_nest.AddLoop( /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, - end_index, unroll_mode); + end_index, unroll_mode, /*prevent_vectorization*/ false, + is_batch_dim); array_multi_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/absl::StrFormat("dim.%d", dimension), unroll_mode); + /*suffix=*/absl::StrFormat("dim.%d", dimension), unroll_mode, + /*prevent_vectorization*/ false, is_batch_dim); array_multi_index[dimension] = loop->GetIndVarValue(); } } diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 9079ed8a16d1a8..b6c32ecf8d4eb0 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -1086,6 +1086,8 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( return EmitTopKThunk(custom_call); } else if (custom_call_target == "SliceToDynamic") { return EmitSliceToDynamicThunk(instruction); + } else if (custom_call_target == "GetOuterBatchValue") { + return EmitGetOuterBatchValueThunk(instruction); } // Check the API version. @@ -1128,6 +1130,21 @@ absl::StatusOr ThunkEmitter::EmitSliceToDynamicThunk( /*min_alignment=*/cpu_function_runtime::MinAlign()); } +absl::StatusOr ThunkEmitter::EmitGetOuterBatchValueThunk( + const HloInstruction* instruction) { + VLOG(2) << "Handling GetOuterBatchValue for instruction: " + << instruction->ToString(); + const HloCustomCallInstruction* custom_call = + Cast(instruction); + TF_ASSIGN_OR_RETURN( + auto kernel, ir_emitter_.EmitGetOuterBatchValueHostKernel(custom_call)); + TF_ASSIGN_OR_RETURN(auto result_buffer, + GetHostKernelAllocationSlices(instruction)); + return MakeKernelThunkSequence( + instruction, result_buffer, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); +} + absl::StatusOr ThunkEmitter::EmitSliceThunk( const HloInstruction* instruction) { // TODO(ezhulenev): Consider implementing slice operations as separate diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 9cdc8ae3981680..9949fef0329371 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -183,6 +183,9 @@ class ThunkEmitter { absl::StatusOr EmitSliceToDynamicThunk( const HloInstruction* instruction); + absl::StatusOr EmitGetOuterBatchValueThunk( + const HloInstruction* instruction); + absl::StatusOr EmitTopKThunk( const HloCustomCallInstruction* custom_call); diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 89391a3bdd1dd5..2963fd4a03d5b1 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -3278,51 +3278,56 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( } // We use bisection to select the input operand. - int64_t current_offset = 0; + int64_t coffset = 0; + llvm::Value* current_offset = source_index.GetConstantWithIndexType(0); // Offset for every operand. - std::vector> cases; + std::vector> cases; cases.reserve(hlo->operand_count()); for (const HloInstruction* operand : hlo->operands()) { cases.emplace_back(current_offset, operand); - current_offset += operand->shape().dimensions(concat_dim); + llvm::Value* cdim = source_index.GetConstantWithIndexType( + operand->shape().dimensions(concat_dim)); + if (concat_dim == 0 && operand->shape().outer_multiplier() > 0) { + cdim = llvm_ir::GetBatchDimByName(b_, operand->shape().outer_multiplier()); + } + current_offset = b_->CreateAdd(current_offset, cdim, "current_offset"); + coffset += operand->shape().dimensions(concat_dim); } - CHECK_EQ(current_offset, hlo->shape().dimensions(concat_dim)); + CHECK_EQ(coffset, hlo->shape().dimensions(concat_dim)); std::function> operands)> + absl::Span> operands)> emit_tree = - [&](absl::Span> + [&](absl::Span> operands) { llvm::IRBuilder<>::InsertPointGuard guard(*b_); size_t mid = operands.size() / 2; - const std::pair& pivot = + const std::pair& pivot = operands[mid]; llvm::BasicBlock* block = llvm_ir::CreateBasicBlock( exit_block, - absl::StrCat("concatenate.pivot.", pivot.first, "."), b_); + absl::StrCat("concatenate.pivot."), b_); b_->SetInsertPoint(block); // If there's only one element we're done. The range is contiguous // so we can just jump to the block for it. if (operands.size() == 1) { - const std::pair& operand = + const std::pair& operand = operands.back(); int64_t operand_id = to_unique_operand_id[operand.second]; source_index_phis[operand_id]->addIncoming( - source_index.GetConstantWithIndexType(operand.first), + operand.first, b_->GetInsertBlock()); b_->CreateBr(emit_operand_blocks[operand_id]); return block; } // Take the middle element and recurse. - llvm::Constant* pivot_const = llvm::ConstantInt::get( - source_index[concat_dim]->getType(), pivot.first); llvm::Value* comp = - b_->CreateICmpULT(source_index[concat_dim], pivot_const); + b_->CreateICmpULT(source_index[concat_dim], pivot.first); llvm::BasicBlock* left_block = emit_tree(operands.subspan(0, mid)); llvm::BasicBlock* right_block = emit_tree(operands.subspan(mid)); @@ -3616,11 +3621,21 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( "in_bounds"); multi_index[i] = SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = - And(in_bounds, - ICmpSLT(multi_index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); + + int64_t shape_dim = hlo->operand(0)->shape().dimensions(i); + llvm::Value* bound = index_typed_const(shape_dim); + + int64_t multiplier = + (i == 0) ? hlo->operand(0)->shape().outer_multiplier() : -1; + if (multiplier > 0) { + bound = llvm_ir::GetBatchDimByName(b_, multiplier); + } else if (shape_dim == 977) { + // This should be deleted. + LOG(ERROR) << "Dynamic batch marker: No multiplier for batch dim: " + << hlo->ToString(); + } + + in_bounds = And(in_bounds, ICmpSLT(multi_index[i], bound), "in_bounds"); } // if (in_bounds) { @@ -3687,9 +3702,18 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDot( return llvm::ConstantInt::get(index_type, c); }; + llvm::Value* contracted_bound = index_typed_const(contracted_dim_size); + int64_t multiplier = (lhs_contracting_dim == 0) + ? hlo->operand(0)->shape().outer_multiplier() + : -1; + if (multiplier > 0) { + llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(b_, multiplier); + contracted_bound = bdim_value; + } + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), index_typed_const(0), - index_typed_const(contracted_dim_size), index_typed_const(1), b_); + IrName(hlo, "inner"), index_typed_const(0), contracted_bound, + index_typed_const(1), b_); SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_); PrimitiveType primitive_type = hlo->shape().element_type(); @@ -3967,6 +3991,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> absl::StatusOr { + // [Steven] the problem is that slices are represented by integer ranges. + // If these are based on the magic number they are wrong. IrArray::Index sliced_index = index.SourceIndexOfSlice( /*operand_shape=*/hlo->operand(0)->shape(), /*starts=*/hlo->slice_starts(), @@ -4236,11 +4262,22 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( // comparison is equivalent to the unsigned comparison // input_multi_index[i] < bound, as a negative value wraps to a large // positive value. + + int64_t dim_bound = reduce_window->inputs()[0]->shape().dimensions(i); + llvm::Value* shape_bound = index_typed_const(dim_bound); + + int64_t multiplier = + (i == 0) ? reduce_window->inputs()[0]->shape().outer_multiplier() : -1; + + if (multiplier > 0) { + llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(b_, multiplier); + shape_bound = bdim_value; + } + in_bounds = And(in_bounds, ICmpULT(input_multi_index[i], - index_typed_const( - reduce_window->inputs()[0]->shape().dimensions(i)))); + shape_bound)); } llvm_ir::LlvmIfData if_data = diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD index 599060b88eee81..2a665c2d77f3b2 100644 --- a/third_party/xla/xla/service/llvm_ir/BUILD +++ b/third_party/xla/xla/service/llvm_ir/BUILD @@ -84,6 +84,7 @@ cc_library( "//xla/service/cpu:cpu_options", "//xla/tsl/platform:byte_order", "//xla/tsl/platform:logging", + "//xla/service/cpu:executable_run_options_offset", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 139d537c88778a..389eee09ef9bda 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -559,8 +559,43 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, int64_t dimension = LayoutUtil::Major(shape_.layout(), i); gep_indices.push_back(actual_index[dimension]); } - return b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, - llvm_ir::AsStringRef(name)); + +#define DYN_DIMS +#ifdef DYN_DIMS + + llvm::ArrayType* outerArray = llvm::dyn_cast(pointee_type_); + + CHECK(outerArray) << "Expected outer array type."; + + llvm::Value* gep; + + if (shape_.outer_multiplier() > 0) { + + // Extract the inner array type: [N x T] + llvm::Type* innerArray = outerArray->getElementType(); + + CHECK(innerArray) << "Expected inner array type."; + + // Create a new array type: [0 x [N x T]] + llvm::ArrayType* zeroOuterArray = llvm::ArrayType::get(innerArray, 0); + + llvm::PointerType* newPtrTy = llvm::PointerType::getUnqual(zeroOuterArray); + llvm::Value* castedPtr = b->CreateBitCast(base_ptr_, newPtrTy); + + gep = + b->CreateInBoundsGEP(zeroOuterArray, + castedPtr, + gep_indices, llvm_ir::AsStringRef(name)); + } else { + gep = b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); + } +#else + auto gep = b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); +#endif + + return gep; } llvm::Value* IrArray::EmitLinearArrayElementAddress( diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc index 43ee77ef0e85be..f3159e0aacd04c 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/tsl/platform/logging.h" +#include "llvm/include/llvm/Support/Debug.h" namespace xla { namespace llvm_ir { @@ -189,20 +190,29 @@ std::unique_ptr ForLoopNest::AddLoop(absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, - bool prevent_vectorization) { + bool prevent_vectorization, + bool is_batch_dim) { return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), - unroll_mode, prevent_vectorization); + unroll_mode, prevent_vectorization, is_batch_dim); } std::unique_ptr ForLoopNest::AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization, + bool is_batch_dim) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } + llvm::Value* actual_end = end_index; + if (is_batch_dim) { + // Get batch dim and compare with end_index to use minimum value + llvm::Value* batch_dim = GetBatchDimByName(b_); + actual_end = b_->CreateSelect(b_->CreateICmpULT(end_index, batch_dim), + end_index, batch_dim, "loop_end_min"); + } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + /*prefix=*/name_, suffix, start_index, actual_end, stride, unroll_mode, prevent_vectorization)); loop->Emit(b_); @@ -223,23 +233,30 @@ std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, int64_t end_index, absl::string_view suffix, UnrollMode unroll_mode, - bool prevent_vectorization) { + bool prevent_vectorization, + bool is_batch_dim) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, GetConstantWithIndexType(start_index), - GetConstantWithIndexType(end_index), unroll_mode, - prevent_vectorization); + + llvm::Value* end = is_batch_dim ? GetBatchDimByName(b_) : GetConstantWithIndexType(end_index); + is_batch_dim = false; + return AddLoop(suffix, GetConstantWithIndexType(start_index), end, + unroll_mode, prevent_vectorization, is_batch_dim); } std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, int64_t end_index, int64_t stride, absl::string_view suffix, UnrollMode unroll_mode, - bool prevent_vectorization) { + bool prevent_vectorization, + bool is_batch_dim) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, GetConstantWithIndexType(start_index), - GetConstantWithIndexType(end_index), + + llvm::Value* end = is_batch_dim ? GetBatchDimByName(b_) : GetConstantWithIndexType(end_index); + is_batch_dim = false; + + return AddLoop(suffix, GetConstantWithIndexType(start_index), end, GetConstantWithIndexType(stride), unroll_mode, - prevent_vectorization); + prevent_vectorization, is_batch_dim); } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, @@ -255,11 +272,14 @@ std::vector ForLoopNest::AddLoopsForShapeOnDimensions( absl::string_view suffix) { std::vector multi_index(shape.dimensions().size()); for (int64_t dimension : dimensions) { + bool is_batch_dim = (dimension == 0) && shape.outer_multiplier() > 0; std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, absl::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension)), + /*unroll_mode=*/llvm_ir::UnrollMode::kDefaultUnroll, + /*prevent_vectorization=*/false, is_batch_dim); multi_index[dimension] = loop->GetIndVarValue(); } return multi_index; diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.h b/third_party/xla/xla/service/llvm_ir/llvm_loop.h index 7aa8ce9e32e950..c6cbcafdda6dbb 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.h @@ -201,14 +201,14 @@ class ForLoopNest { absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, bool is_batch_dim = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, bool is_batch_dim = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. @@ -216,13 +216,13 @@ class ForLoopNest { int64_t start_index, int64_t end_index, int64_t stride, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, bool is_batch_dim = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( int64_t start_index, int64_t end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, bool is_batch_dim = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index b1db780705b41f..b337b0dcc5bb32 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -80,11 +80,15 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/profiler/lib/scoped_annotation.h" +#include "xla/service/cpu/executable_run_options_offset.h" + namespace xla { namespace llvm_ir { namespace { +constexpr llvm::StringLiteral kBdimValueName("bdim_value"); + // This works for most llvm / mlir types. This also accepts a const pointer to // objects which have a const print() method. template @@ -811,5 +815,63 @@ void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilderBase* b, b->SetInsertPoint(continued, continued->getFirstInsertionPt()); } +llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier) { + llvm::Function* function = b->GetInsertBlock()->getParent(); + llvm::LLVMContext& ctx = b->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + llvm::Value* loadedValue = nullptr; + llvm::Value* bdim_scaled = nullptr; + for (auto& inst : function->getEntryBlock()) { + if (inst.getName() == kBdimValueName) { + loadedValue = &inst; + } + } + if (!loadedValue) { + // if missing, try to materialize it by loading from run_options+offset + llvm::Value* run_options = nullptr; + for (auto& arg : function->args()) { + if (arg.getName() == "run_options") { + run_options = &arg; + break; + } + } + + if (!run_options) { + return nullptr; + } + // Materialize %bdim_value by loading ExecutableRunOptions::batch_size_ from + // the 'run_options' function argument using a known byte offset: + // + // 1) Bitcast 'run_options' to i8* to do byte-address arithmetic. + // 2) GEP by 'off' bytes to reach the batch_size field inside the object. + // 3) Bitcast resulting i8* to i64* and load it. + const int64_t off = + static_cast(xla::cpu::ExecutableRunOptionsBatchSizeOffset()); + + // Insert at the entry block. + llvm::IRBuilder<> entry_builder( + &function->getEntryBlock(), + function->getEntryBlock().getFirstInsertionPt()); + llvm::Type* i8 = entry_builder.getInt8Ty(); + llvm::Value* ro_i8 = entry_builder.CreateBitCast( + run_options, llvm::PointerType::getUnqual(i8), "run_options_i8"); + llvm::Value* off_c = llvm::ConstantInt::get(i64Type, off); + llvm::Value* bdim_ptr_i8 = + entry_builder.CreateInBoundsGEP(i8, ro_i8, off_c, "bdim_ptr_i8"); + llvm::Value* bdim_ptr = entry_builder.CreateBitCast( + bdim_ptr_i8, llvm::PointerType::getUnqual(i64Type), "bdim_ptr"); + loadedValue = entry_builder.CreateLoad(i64Type, bdim_ptr, kBdimValueName); + } + if (multiplier < 1) { + llvm::errs() << "Multiplier is less than 1, this should not happen.\n"; + } else if (multiplier == 1) { + bdim_scaled = loadedValue; + } else { + llvm::ConstantInt* m = llvm::ConstantInt::get(i64Type, multiplier, true); + bdim_scaled = b->CreateMul(loadedValue, m, "bdim_scaled"); + } + return bdim_scaled; +} + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index 88c1287d2f236d..732e4400e3d467 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -334,6 +334,8 @@ llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilderBase* b); void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilderBase* b, llvm::BasicBlock* return_block = nullptr); +llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier = 1); + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index 13f6a67764b346..b0879c8ed4b3de 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -28,11 +28,13 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "xla/layout_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/errors.h" @@ -183,6 +185,33 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( ForLoopNest loop_nest(loop_name, b_); + llvm::LLVMContext& ctx = b_->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + + llvm::PointerType* ptr = llvm::PointerType::getUnqual(ctx); + llvm::StructType* callFrameTy = llvm::StructType::create( + "XLA_CPU_KernelArg", ptr, ptr, i64Type, ptr, i64Type); + + std::vector dynamic_dims; + for (auto dim : shape_.dimensions()) { + dynamic_dims.push_back(llvm::ConstantInt::get(i64Type, dim)); + } + + bool dynamic = false; + for (int i = 0; i < shape_.dimensions_size(); i++) { + int64_t multiplier = (i == 0) ? shape_.outer_multiplier() : -1; + if (multiplier > 0) { + dynamic_dims[i] = xla::llvm_ir::GetBatchDimByName(b_, multiplier); + shape_.set_dynamic_dimension(i, true); + dynamic = true; + } + } + + if (dynamic) { + // Assign dynamic batch + dynamic_dims_ = dynamic_dims; + } + IrArray::Index array_index = dynamic_dims_.empty() ? EmitStaticIndex(&loop_nest, index_type) : EmitDynamicIndex(&loop_nest, index_type); diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 8453dc17717e11..d9dc14ad4deb41 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -532,7 +532,10 @@ class Shape { return Shape::Hash(std::move(h), s); } + int64_t outer_multiplier() const { return outer_multiplier_; } + void set_outer_multiplier(int64_t m) { outer_multiplier_ = m; } private: + int64_t outer_multiplier_ = -1; friend absl::Status ValidateNonLayoutProperties(const Shape& shape); // Define one state struct for each shape category. Depending on the element From 4fbd7f7a9bd5ca7df8c9f6a0d30064b15a26cec2 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:31:26 +0000 Subject: [PATCH 04/16] Add batch matcher and batch-aware output handling Adds batch matching support and updates output handling so compiled executions can normalize and restore batch-aware values. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- tensorflow/compiler/jit/BUILD | 12 ++ .../jit/device_compilation_profiler.cc | 24 ++- tensorflow/compiler/jit/device_compiler.h | 12 +- tensorflow/compiler/jit/flags.cc | 6 + tensorflow/compiler/jit/flags.h | 5 + tensorflow/compiler/jit/kernels/BUILD | 1 + tensorflow/compiler/jit/kernels/xla_ops.cc | 59 ++++-- tensorflow/compiler/jit/xla_batch_matcher.cc | 185 ++++++++++++++++++ tensorflow/compiler/jit/xla_batch_matcher.h | 40 ++++ tensorflow/compiler/jit/xla_launch_util.cc | 25 ++- tensorflow/core/framework/BUILD | 2 + .../batch_size_resource.h | 0 .../core/grappler/optimizers/remapper.cc | 14 ++ tensorflow/core/kernels/BUILD | 10 - tensorflow/core/kernels/function_ops.cc | 9 +- third_party/xla/xla/debug_options_flags.cc | 10 + .../xla/xla/service/cpu/cpu_executable.cc | 8 +- .../xla/xla/service/layout_assignment.cc | 2 +- third_party/xla/xla/shape.cc | 6 + third_party/xla/xla/shape.h | 5 + third_party/xla/xla/shape_util.cc | 11 +- third_party/xla/xla/xla.proto | 4 +- third_party/xla/xla/xla_data.proto | 2 + 23 files changed, 411 insertions(+), 41 deletions(-) create mode 100644 tensorflow/compiler/jit/xla_batch_matcher.cc create mode 100644 tensorflow/compiler/jit/xla_batch_matcher.h rename tensorflow/core/{kernels => framework}/batch_size_resource.h (100%) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 39f93d17aa2932..89a05cb80eba8f 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -780,6 +780,8 @@ cc_library( ":flags_headers", ":tf_graph_to_hlo_compiler", ":xla_compile_util", + ":xla_batch_matcher", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", @@ -2005,3 +2007,13 @@ tf_cuda_cc_test( "@local_xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", ], ) + +cc_library( + name = "xla_batch_matcher", + srcs = ["xla_batch_matcher.cc"], + hdrs = ["xla_batch_matcher.h"], + deps = [ + "//tensorflow/core/platform:logging", + "@local_xla//xla:debug_options_flags", + ], +) \ No newline at end of file diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc index 5e1b3b26e8ecb5..f8a742ccc01148 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.cc +++ b/tensorflow/compiler/jit/device_compilation_profiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -32,9 +33,30 @@ limitations under the License. namespace tensorflow { namespace { bool ShouldBeMegamorphic(int64_t compile_count, int64_t execution_count) { - const int64_t kCompileThreshold = 10; + int64_t kCompileThreshold = 10; const int64_t kMinExecutionsPerCompile = 50; + int64_t tf_xla_threshold_for_megamorphic = + GetMarkForCompilationPassFlags()->tf_xla_threshold_for_megamorphic; + + // Negative values other that -1 cannot be used + if (tf_xla_threshold_for_megamorphic < -1) { + LOG(FATAL) << "The value for the tf_xla_threshold_for_megamorphic flag " + << "is out of range.\n" + << "Allowed ranges are (-1) to " + << std::numeric_limits::max() + << " got " << tf_xla_threshold_for_megamorphic << "."; + } + + // -1: setting clusters as Megamorphic is disabled + // 0 Default behaviour in Tensorflow + // Any other number sets the compilation threshold + if (tf_xla_threshold_for_megamorphic == -1) { + return false; + } else if (tf_xla_threshold_for_megamorphic > 0) { + kCompileThreshold = tf_xla_threshold_for_megamorphic; + } + // This heuristic is trying to capture the following property: have we sunk a // certain minimum amount of compile time into the cluster that didn't quite // "pay off"? diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index 34b22033129b96..28e27c696e788a 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h" #include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/jit/xla_batch_matcher.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/op_kernel.h" @@ -125,7 +126,8 @@ class DeviceCompiler : public ResourceBase { DeviceCompilerClient* compiler_client() { return compiler_client_.get(); } - + XlaBatchMatcher* xla_batch_matcher() { return xla_batch_matcher_.get(); } + string DebugString() const override; private: @@ -177,6 +179,9 @@ class DeviceCompiler : public ResourceBase { // Pool of threads for asynchronous compilations. std::unique_ptr async_compiler_threads_; + // Specified dynamic batch padding values. + std::unique_ptr xla_batch_matcher_; + mutex cluster_mutexes_mu_; absl::flat_hash_map, DeviceCompilationClusterSignature::Hash> @@ -225,6 +230,11 @@ DeviceCompiler::DeviceCompiler( async_compiler_threads_ = std::make_unique( tensorflow::Env::Default(), "async_compiler_threads", kNumAsyncDeviceCompilerThreads); + + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + xla_batch_matcher_ = std::make_unique(); + } } template diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index b3b383d7530ee5..5d8a1904ab49a2 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -187,6 +187,11 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { &mark_for_compilation_flags->tf_xla_persistent_cache_prefix, "Specifies the persistance cache prefix. Default is " "\"xla_compile_cache\""), + Flag("tf_xla_threshold_for_megamorphic", + &mark_for_compilation_flags->tf_xla_threshold_for_megamorphic, + "Sets the threshold for marking a cluster megamorphic. " + "Setting it to -1 disables marking clusters megamorphic." + "Setting it to 0 uses the default behaviour of TensorFlow."), Flag("tf_xla_sparse_core_disable_table_stacking", &sparse_core_flags->tf_xla_sparse_core_disable_table_stacking, "Disable table stacking for all the tables passed to the SparseCore" @@ -250,6 +255,7 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_cpu_global_jit = false; mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_threshold_for_megamorphic = 0; mark_for_compilation_flags ->tf_xla_disable_deadness_safety_checks_for_debugging = false; mark_for_compilation_flags diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index ebb6b7a518146e..971dd8a7a38229 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -120,6 +120,11 @@ struct MarkForCompilationPassFlags { // Specifies the persistance cache prefix. Default is "xla_compile_cache" string tf_xla_persistent_cache_prefix; + + // Sets the threshold for marking a cluster megamorphic. + // Setting it to -1 disables marking clusters megamorphic. + // Setting it to 0 uses the default behaviour of TensorFlow. + int64_t tf_xla_threshold_for_megamorphic; }; // Flags associated with XLA Sparse Core. diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index dec057ebfba9ab..8da9744bacd79e 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/jit:tf_graph_to_hlo_compiler", "//tensorflow/compiler/jit:tf_to_hlo_compiler", "//tensorflow/compiler/jit:xla_compile_util", + "//tensorflow/compiler/jit:xla_batch_matcher", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:refcount", "@com_google_absl//absl/base:core_headers", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 2a8b1110639b49..c9c8df8a835f47 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_host_send_device_context.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/jit/xla_batch_matcher.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -65,13 +66,13 @@ limitations under the License. #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/batch_size_resource.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/batch_size_resource.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" @@ -434,7 +435,8 @@ absl::Status CompileToLocalExecutable( // dimension, detecting dynamic dimension via _is_batch attr in the // argument. std::vector norm_args(args.begin(), args.end()); - constexpr int64_t kMagicBound = 977; + int64_t filled_batch = 0; + XlaBatchMatcher* xla_batch_matcher = xla_device_compiler->xla_batch_matcher(); if (options.flib_def != nullptr) { const FunctionDef* fdef = options.flib_def->Find(function.name()); if (fdef != nullptr) { @@ -446,23 +448,31 @@ absl::Status CompileToLocalExecutable( const AttrValue& v = it->second; if (it == attr_map.end()) continue; norm_args[arg_index].dynamic_dim = 0; + + if (!filled_batch && xla_batch_matcher) { + TensorShape& shp = std::get(norm_args[arg_index].shape); + filled_batch = xla_batch_matcher->get_xla_compile_batch(shp.dim_size(0)); + } } } } - for (int i = 0; i < norm_args.size(); ++i) { - auto& arg = norm_args[i]; - // argument rewrite. - if (arg.dynamic_dim == 0) { - TensorShape& shp = std::get(arg.shape); - int64_t old = shp.dim_size(0); - shp.set_dim(0, kMagicBound); - } - // constant argument rewrite otherwise it still store the incoming batch - // request. - if (arg.kind == XlaCompiler::Argument::kConstant) { - auto flat = arg.constant_value.flat(); - int32 old_batch = flat(0); - flat(0) = static_cast(kMagicBound); + + if (filled_batch) { + for (int i = 0; i < norm_args.size(); ++i) { + auto& arg = norm_args[i]; + // argument rewrite. + if (arg.dynamic_dim == 0) { + TensorShape& shp = std::get(arg.shape); + int64_t old = shp.dim_size(0); + shp.set_dim(0, filled_batch); + } + // constant argument rewrite otherwise it still store the incoming batch + // request. + if (arg.kind == XlaCompiler::Argument::kConstant) { + auto flat = arg.constant_value.flat(); + int32 old_batch = flat(0); + flat(0) = static_cast(filled_batch); + } } } @@ -1014,13 +1024,20 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { BatchSizeResource* bsr = nullptr; ScopedStepContainer* step_container = ctx->step_container(); - OP_REQUIRES_OK(ctx, step_container->Lookup( - ctx->resource_manager(), BatchSizeResourceName, &bsr)); + absl::Status st = step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr); - run_options.set_batch_size(bsr->GetBatchSize()); - VLOG(1) << "run_options.batch_size is set to: " + if (st.ok()) { + run_options.set_batch_size(bsr->GetBatchSize()); + VLOG(1) << "run_options.batch_size is set to: " << run_options.batch_size() << ". step_id: " << ctx->step_id(); - bsr->Unref(); + bsr->Unref(); + + } else if (IsNotFound(st)) { + VLOG(1) << "Warning: Not found BatchSizeResource in step_container."; + } else { + OP_REQUIRES_OK(ctx, st); + } } // Host callbacks used for HLO send/recv. diff --git a/tensorflow/compiler/jit/xla_batch_matcher.cc b/tensorflow/compiler/jit/xla_batch_matcher.cc new file mode 100644 index 00000000000000..088eef4139855a --- /dev/null +++ b/tensorflow/compiler/jit/xla_batch_matcher.cc @@ -0,0 +1,185 @@ +#include "tensorflow/compiler/jit/xla_batch_matcher.h" +#include "xla/debug_options_flags.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +XlaBatchMatcher::XlaBatchMatcher() { + const std::string& xla_compile_batch_sizes = + xla::GetDebugOptionsFromFlags().xla_compile_batch_sizes(); + env_str_ = std::getenv(xla_compile_batch_sizes.c_str()); + parse_env_config(); +} + +// Trim whitespace (spaces/tabs) from both ends of a string +std::string trim(const std::string& s) { + size_t start = s.find_first_not_of(" \t"); + size_t end = s.find_last_not_of(" \t"); + return (start == std::string::npos) ? "" : s.substr(start, end - start + 1); +} + +std::vector XlaBatchMatcher::parse_single_item(const std::string& item) { + std::vector batch_list; + if (item.empty()) return batch_list; + + // 1. Parse single value (no colon separator) + if (item.find(':') == std::string::npos) { + // Validate all characters are digits (reject non-numeric chars) + for (char c : item) { + if (!isdigit(c)) { + throw std::invalid_argument("Non-numeric characters: " + item); + } + } + auto val = static_cast(std::stoll(item)); + if (val <= 0 || val > kMaxBatch) { + throw std::invalid_argument("Out of valid range (1-" + std::to_string(kMaxBatch) + "): " + item); + } + batch_list.push_back(val); + return batch_list; + } + + // 2. Parse range format (start:end:step) + std::stringstream ss(item); + std::string part; + std::vector parts; + while (std::getline(ss, part, ':')) { + parts.push_back(trim(part)); + } + if (parts.size() > 3) { + throw std::invalid_argument("Invalid range format (requires start:end:step): " + item); + } + + // Convert and validate range parameters + int64_t start, end, step; + try { + start = std::stoi(parts[0]); + end = std::stoi(parts[1]); + step = parts.size() == 2 ? 1 : std::stoi(parts[2]); + } catch (...) { + throw std::invalid_argument("Invalid numeric values in range: " + item); + } + if (start <= 0 || end <= 0 || step <= 0) { + throw std::invalid_argument("Range parameters must be positive integers: " + item); + } + if (start > end) { + throw std::invalid_argument("Start value > end value in range: " + item); + } + if (end > kMaxBatch) { + throw std::invalid_argument("Range exceeds max limit (" + std::to_string(kMaxBatch) + "): " + item); + } + + // Generate batch list from range + for (int64_t i = start; i <= end; i += step) { + batch_list.push_back(i); + } + return batch_list; +} + +void XlaBatchMatcher::print_all_batches() { + std::ostringstream oss; + oss << "[XLA_BATCH_INFO] Valid batch list update: "; + for (size_t i = 0; i < all_batches_.size(); ++i) { + if (i > 0) oss << ", "; + oss << all_batches_[i]; + } + LOG(INFO) << oss.str(); +} + +// Parse environment variable config into deduplicated, sorted batch list +// For example, export XLA_COMPILE_BATCH_SIZES="10:100:10, 977" +void XlaBatchMatcher::parse_env_config() { + // If the env var not set or is empty, filled with the nearest power of two by default + if (!env_str_) { + LOG(INFO) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << + " not set, filled with the nearest power of two by default"; + return; + } + + if (std::string(env_str_).empty()) { + LOG(INFO) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << + "is empty, filled with the nearest power of two by default"; + return; + } + + // Split config by commas + std::stringstream ss(env_str_); + std::string item; + while (std::getline(ss, item, ',')) { + std::string trimmed_item = trim(item); + if (trimmed_item.empty()) continue; + + // Parse single item (skip on failure to avoid breaking other items) + try { + std::vector item_batches = parse_single_item(trimmed_item); + all_batches_.insert(all_batches_.end(), item_batches.begin(), item_batches.end()); + } catch (const std::exception& e) { + LOG(INFO) << "[XLA_BATCH_WARN] Failed to parse config item, skipping: " << + trimmed_item << " (" << e.what() << ")"; + } + } + + if (!all_batches_.empty()) { + std::sort(all_batches_.begin(), all_batches_.end()); + auto last = std::unique(all_batches_.begin(), all_batches_.end()); + all_batches_.erase(last, all_batches_.end()); + } + + // Print parsed result + if (!all_batches_.empty()) print_all_batches(); + return; +} + +// Calculate the smallest power of two greater than the real batch +static int64_t GetNextPowerOfTwo(int64_t real_batch) { + // If real_batch is already a power of two, return real_batch directly + if ((real_batch & (real_batch - 1)) == 0) { + return real_batch; + } + + int64_t power = 1; + while (power < real_batch) { + power <<= 1; + } + return power; +} + +int64_t XlaBatchMatcher::find_min_larger_batch(int64_t real_batch) { + if (real_batch <= 0 || real_batch > kMaxBatch) { + LOG(INFO) << "[XLA_BATCH_WARN] Out of valid range: " << real_batch; + return real_batch; + } + + if (all_batches_.empty()) { + // Return the next power of two directly without modifying all_batches_ + return GetNextPowerOfTwo(real_batch); + } + + // Edge case 1: Real value < the smallest batch, use smallest + if (real_batch < all_batches_.front()) { + return all_batches_.front(); + } + // Edge case 2: Real value ≥ the largest batch, use the nearest power of two + if (real_batch > all_batches_.back()) { + int64_t val = GetNextPowerOfTwo(real_batch); + all_batches_.emplace_back(val); + print_all_batches(); + return val; + } + + // Find first batch larger than real value (binary search via lower_bound) + auto it = std::lower_bound(all_batches_.begin(), all_batches_.end(), real_batch); + return (it != all_batches_.end()) ? *it : all_batches_.back(); +} + +int64_t XlaBatchMatcher::get_xla_compile_batch(int64_t real_batch) { + // Match target batch size + int64_t selected = find_min_larger_batch(real_batch); + if (real_batch != last_batch_ || all_batches_.empty()) { + last_batch_ = real_batch; + LOG(INFO) << "[XLA_BATCH_INFO] Real batch: " << real_batch + << " -> Selected compile batch: " << selected; + } + return selected; +} + +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/compiler/jit/xla_batch_matcher.h b/tensorflow/compiler/jit/xla_batch_matcher.h new file mode 100644 index 00000000000000..0ef9a2dd13afcd --- /dev/null +++ b/tensorflow/compiler/jit/xla_batch_matcher.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_BATCH_MATCHER_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_BATCH_MATCHER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorflow { + +// Define the maximum allowed batch size: 2147483648 >> 1 = 1073741824 +// Exceeding this value will make the next power of two exceed the safe range +constexpr int kMaxBatch = 2147483648ULL >> 1; + +class XlaBatchMatcher { + public: + XlaBatchMatcher(); + virtual ~XlaBatchMatcher() = default; + int64_t get_xla_compile_batch(int64_t real_batch); + std::vector get_all_batches() { return all_batches_; } + + private: + void parse_env_config(); + void print_all_batches(); + std::vector parse_single_item(const std::string& item); + int64_t find_min_larger_batch(int64_t real_batch); + + std::vector all_batches_; + const char* env_str_; + int64_t last_batch_ = -1; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_BATCH_MATCHER_H_ \ No newline at end of file diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 588161df309131..3f0e5df76b8c40 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/batch_size_resource.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -430,10 +431,32 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( } } else { for (int i = 0; i < ctx->num_outputs(); ++i) { - output_tensor_shapes.push_back(compilation_result->outputs[i].shape); + xla::Shape output_host_shape = output.on_host_shape(); + const xla::Shape& subshape = xla::ShapeUtil::GetSubshape(output_host_shape, {i}); + VLOG(2) << "PopulateOutputs: subshape[" << i << "]: "<< subshape; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); + if (subshape.outer_multiplier() > 0) { + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + TF_RETURN_IF_ERROR(step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr)); + auto bsm = bsr->GetBatchSize() * subshape.outer_multiplier() ; + shape.set_dim(0, bsm); + output_tensor_shapes.push_back(shape); + bsr->Unref(); + } + else { + output_tensor_shapes.push_back(compilation_result->outputs[i].shape); + } } } + VLOG(2) << "output_tensor_shapes:"; + for (auto s:output_tensor_shapes) { + VLOG(2) << s; + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0, end = ctx->num_outputs(); i < end; ++i) { diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 09142e303e3e13..711c148aec958e 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -62,6 +62,7 @@ package( exports_files( srcs = [ "allocator_registry.h", + "batch_size_resource.h", "collective.h", "control_flow.h", "dataset.h", @@ -195,6 +196,7 @@ filegroup( "allocator.h", "allocator_registry.h", "attr_value_util.h", + "batch_size_resource.h", "bfloat16.h", "bounds_check.h", "cancellation.h", diff --git a/tensorflow/core/kernels/batch_size_resource.h b/tensorflow/core/framework/batch_size_resource.h similarity index 100% rename from tensorflow/core/kernels/batch_size_resource.h rename to tensorflow/core/framework/batch_size_resource.h diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index a5a48347f07517..c468801a6a74db 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -741,6 +741,16 @@ bool IsBiasSemanticAdd(const RemapperContext& ctx, return false; } +bool MaybeCopyOutputShapesAttr(const NodeDef& from, NodeDef* fused_op) { + const string output_shape_attr_name = "_output_shapes"; + if (from.attr().count(output_shape_attr_name) > 0) { + auto shape_attrs = from.attr().at(output_shape_attr_name); + AddNodeAttr(output_shape_attr_name, shape_attrs, fused_op); + return true; + } + return false; +} + void AddInputShapesAttr(const RemapperContext& ctx, int node_index) { auto mutable_node = ctx.graph_view.graph()->mutable_node(node_index); @@ -3334,6 +3344,7 @@ absl::Status AddFusedContractionNode(RemapperContext* ctx, } SetFusedOpAttributes(&fused_op, {"BiasAdd"}); + MaybeCopyOutputShapesAttr(bias_add, &fused_op); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; mutation->AddNode(std::move(fused_op), &status); @@ -3439,6 +3450,7 @@ absl::Status AddFusedContractionNode( } SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()}); + MaybeCopyOutputShapesAttr(activation, &fused_op); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; @@ -3630,6 +3642,7 @@ absl::Status AddFusedContractionNode( } SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2); + MaybeCopyOutputShapesAttr(add, &contraction_node); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; @@ -3729,6 +3742,7 @@ absl::Status AddFusedContractionNode( } SetFusedOpAttributes(&fused_conv, {"BiasAdd", "Add", activation.op()}, 2); + MaybeCopyOutputShapesAttr(add, &fused_conv); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8b71e9a4c0ad7c..d25f6736317f61 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3072,7 +3072,6 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/kernels:batch_size_resource", ], ) @@ -7973,15 +7972,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "batch_size_resource", - srcs = ["batch_size_resource.h"], - deps = [ - "//tensorflow/core:framework", - ], -) - - # For a more maintainable build this target should not exist and the headers # should be split into the existing cc_library targets, but this change was # automatically done so that we can remove long standing issues and complexity diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 6ea77f78f04797..eec275448c50ce 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -24,13 +24,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/gradients.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/framework/batch_size_resource.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/full_type_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/kernels/batch_size_resource.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/device_name_utils.h" @@ -43,6 +43,13 @@ static constexpr const char* const kGradientOp = ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + + Status s = ctx->GetAttr("_is_batch", &is_batch_); + if (IsNotFound(s)) { + is_batch_ = false; + } else { + OP_REQUIRES_OK(ctx, s); + } } void ArgOp::Compute(OpKernelContext* ctx) { diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 33fa90f7e35e9e..961273151ed5e0 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -103,6 +103,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_use_fusion_emitters(true); opts.set_xla_cpu_use_thunk_runtime(true); opts.set_xla_cpu_use_xnnpack(false); + opts.set_xla_compile_batch_sizes(""); opts.set_xla_cpu_experimental_xnn_graph_fusion_mode( DebugOptions::XNN_GRAPH_FUSION_MODE_DISABLED); opts.set_xla_cpu_parallel_codegen_split_count(32); @@ -1003,6 +1004,15 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "`XNN_GRAPH_FUSION_MODE_DISABLED` - default value, " "`XNN_GRAPH_FUSION_MODE_GREEDY` - greedy extraction of " "XNNPACK-compatible subgraphs starting from root instructions.")); + flag_list->push_back(tsl::Flag( + "xla_compile_batch_sizes", + string_setter_for( + &DebugOptions::set_xla_compile_batch_sizes), + debug_options->xla_compile_batch_sizes(), + "Comma-separated list of batch sizes to use for compilation, " + "use single value or start:end:step format. " + "e.g. 32, 64, 128, 10:100:10, " + "empty to use the nearest power of two.")); flag_list->push_back(tsl::Flag( "xla_cpu_parallel_codegen_split_count", int32_setter_for(&DebugOptions::set_xla_cpu_parallel_codegen_split_count), diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index e161f760a5da03..a215511c5aba5e 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -377,13 +377,15 @@ absl::StatusOr CpuExecutable::CreateResultShapedBuffer( absl::Span buffers, absl::Span arguments) { se::Stream* stream = run_options->stream(); - ExecutionOutput result(/*on_device_shape=*/result_shape(), - run_options->allocator(), - stream->parent()->device_ordinal()); const HloInputOutputAliasConfig& input_output_alias = module().input_output_alias_config(); HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); const Shape& root_shape = root->shape(); + // Use root_shape to initialize ExecutionOuput as the batch multiplier info + // is only attached the ROOT + ExecutionOutput result(/*on_device_shape=*/root_shape, //result_shape(), + run_options->allocator(), + stream->parent()->device_ordinal()); // Move se::OwningDeviceMemory values which contain the array(s) of the result // into the respective location in ScopedShapedBuffer which is returned to the diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index b7bea1e55704c5..b5adc212b53a17 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -2571,7 +2571,7 @@ absl::Status LayoutAssignment::PropagateComputationLayouts( *result_layout = computed_computation_layout.result_layout(); } else { TF_RET_CHECK( - Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()( + Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout().IgnoreBatch()( computed_computation_layout.result_layout().shape(), result_layout->shape())); } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 1cee38146fb07d..dbfe4d185dc904 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -136,6 +136,7 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { TF_ASSIGN_OR_RETURN(*shape.mutable_layout(), Layout::FromProto(shape_proto.layout())); } + shape.set_outer_multiplier(shape_proto.batch_multiplier()); return shape; } @@ -163,6 +164,7 @@ ShapeProto Shape::ToProto() const { proto.mutable_tuple_shapes()->Reserve(1); *proto.add_tuple_shapes() = state->buffer_shape[0].ToProto(); } + proto.set_batch_multiplier(outer_multiplier()); return proto; } @@ -509,6 +511,10 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { rhs.is_unbounded_dynamic_dimension(i))) { continue; } + if (i == 0 && ignore_batch_ && (lhs.outer_multiplier() > 0 || rhs.outer_multiplier() > 0)) { + VLOG(3) << "CompareShapes: batch dimension found. Forcely compatible"; + continue; + } if (lhs.dimensions(i) != rhs.dimensions(i)) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index d9dc14ad4deb41..a53297d18ec6fc 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -434,6 +434,10 @@ class Shape { bool operator()(const Shape& lhs, const Shape& rhs); + Equal& IgnoreBatch(bool ignore_batch = true) { + ignore_batch_ = ignore_batch; + return *this; + } Equal& IgnoreLayout(bool ignore_layout = true) { ignore_layout_ = ignore_layout; return *this; @@ -488,6 +492,7 @@ class Shape { } private: + bool ignore_batch_ = false; bool ignore_layout_ = false; bool ignore_tiles_in_layout_ = false; bool ignore_element_size_in_layout_ = false; diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 4f92ce19adb1f9..e0a327e29e57da 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -753,6 +753,11 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return; } PrintHumanString(printer, shape); + if (shape.outer_multiplier() > 0) { + printer->Append("(bm="); + printer->Append(shape.outer_multiplier()); + printer->Append(")"); + } if (!shape.IsArray()) return; if (!shape.has_layout()) return; if (IsScalar(shape)) { @@ -811,6 +816,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { const Shape& rhs) { if (!SameRank(lhs, rhs)) return false; for (int i = 0; i < lhs.dimensions().size(); ++i) { + if (i == 0 && (lhs.outer_multiplier() > 0 || rhs.outer_multiplier() > 0)) { + VLOG(3) << "CompareShapes: batch dimension found. Forcely compatible"; + continue; + } if (!lhs.is_unbounded_dynamic_dimension(i) && !rhs.is_unbounded_dynamic_dimension(i) && lhs.dimensions(i) != rhs.dimensions(i)) { @@ -826,7 +835,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs); + return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout().IgnoreBatch()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index ca8ba0553bd56a..45c87a25d529d4 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -222,6 +222,8 @@ message DebugOptions { // When true, XLA:CPU uses XNNPACK to execute supported operations. bool xla_cpu_use_xnnpack = 359; + string xla_compile_batch_sizes = 389; + // Enabling this will enable optimizations that ignore the possibility of NaN. bool xla_enable_fast_math = 335; @@ -1208,7 +1210,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 389 + // Next id: 390 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 98462a4f6eb83c..85ad79ae841078 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -361,6 +361,8 @@ message ShapeProto { // The layout used to back this shape. LayoutProto layout = 5; + int64 batch_multiplier = 7; + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and Shape::Hash() appropriately to account for the // new field. From 09ff0a6f26298158ec29fe318879a50e6a978c4b Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:32:58 +0000 Subject: [PATCH 05/16] Add symbolic expressions to TensorFlow and XLA shapes Introduces symbolic dimension expressions to TensorFlow and XLA shape representations so dynamic dimensions can be tracked structurally instead of only by concrete values. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../jit/encapsulate_subgraphs_pass.cc | 142 ++++++- tensorflow/compiler/jit/kernels/xla_ops.cc | 332 ++++++++++++++--- .../compiler/jit/mark_for_compilation_pass.cc | 4 +- tensorflow/compiler/jit/shape_inference.cc | 9 +- tensorflow/compiler/jit/xla_launch_util.cc | 29 +- tensorflow/compiler/tf2xla/kernels/beta_op.cc | 9 +- .../compiler/tf2xla/kernels/bincount_op.cc | 11 +- .../tf2xla/kernels/broadcast_to_op.cc | 3 +- .../tf2xla/kernels/clip_by_value_op.cc | 4 +- .../compiler/tf2xla/kernels/const_op.cc | 3 +- .../tf2xla/kernels/conv_op_helpers.cc | 13 +- .../compiler/tf2xla/kernels/conv_ops.cc | 9 +- tensorflow/compiler/tf2xla/kernels/diag_op.cc | 2 +- .../tf2xla/kernels/dynamic_partition_op.cc | 5 +- .../tf2xla/kernels/fake_quantize_ops.cc | 19 +- .../compiler/tf2xla/kernels/gather_op.cc | 3 +- .../compiler/tf2xla/kernels/image_ops.cc | 2 +- .../compiler/tf2xla/kernels/in_topk_op.cc | 6 +- .../tf2xla/kernels/lower_upper_bound_ops.cc | 6 +- .../tf2xla/kernels/matrix_diag_ops.cc | 9 +- .../kernels/matrix_triangular_solve_op.cc | 6 +- tensorflow/compiler/tf2xla/kernels/pack_op.cc | 10 +- .../compiler/tf2xla/kernels/pooling_ops.cc | 8 +- .../kernels/quantize_and_dequantize_op.cc | 5 +- .../tf2xla/kernels/reduction_ops_common.cc | 6 +- tensorflow/compiler/tf2xla/kernels/relu_op.cc | 32 +- .../compiler/tf2xla/kernels/reshape_op.cc | 59 ++- .../compiler/tf2xla/kernels/scatter_nd_op.cc | 3 +- .../tf2xla/kernels/segment_reduction_ops.cc | 3 +- .../compiler/tf2xla/kernels/select_op.cc | 7 +- .../compiler/tf2xla/kernels/shape_op.cc | 17 +- .../compiler/tf2xla/kernels/slice_op.cc | 16 +- .../compiler/tf2xla/kernels/softmax_op.cc | 65 +++- .../tf2xla/kernels/sparse_to_dense_op.cc | 3 +- .../compiler/tf2xla/kernels/split_op.cc | 11 +- .../compiler/tf2xla/kernels/stack_ops.cc | 3 +- .../tf2xla/kernels/stateless_random_ops.cc | 17 +- .../tf2xla/kernels/strided_slice_op.cc | 102 +++-- .../tf2xla/kernels/tensor_array_ops.cc | 20 +- .../tf2xla/kernels/tensor_list_ops.cc | 16 +- .../tf2xla/kernels/tensor_list_utils.cc | 32 +- .../compiler/tf2xla/kernels/tile_ops.cc | 10 +- .../compiler/tf2xla/kernels/unique_op.cc | 41 +- .../compiler/tf2xla/kernels/unpack_op.cc | 3 +- .../compiler/tf2xla/kernels/where_op.cc | 27 +- tensorflow/compiler/tf2xla/layout_util.cc | 1 + tensorflow/compiler/tf2xla/lib/broadcast.cc | 7 +- tensorflow/compiler/tf2xla/lib/broadcast.h | 5 +- tensorflow/compiler/tf2xla/lib/data_format.cc | 9 +- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 8 +- tensorflow/compiler/tf2xla/shape_util.cc | 19 +- tensorflow/compiler/tf2xla/xla_argument.h | 4 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 30 +- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 3 +- tensorflow/core/BUILD | 1 + .../core/common_runtime/constant_folding.cc | 2 +- tensorflow/core/framework/BUILD | 20 + tensorflow/core/framework/common_shape_fns.cc | 8 +- tensorflow/core/framework/shape_inference.cc | 235 +++++++++--- tensorflow/core/framework/shape_inference.h | 47 ++- tensorflow/core/framework/tensor_shape.cc | 127 +++++++ tensorflow/core/framework/tensor_shape.h | 36 ++ tensorflow/core/framework/tensor_shape.proto | 38 ++ .../core/framework/tensor_shape_expr.cc | 196 ++++++++++ tensorflow/core/framework/tensor_shape_expr.h | 219 +++++++++++ .../core/grappler/costs/graph_properties.cc | 300 ++++++++++++++- tensorflow/core/kernels/padding_fifo_queue.cc | 4 + tensorflow/core/kernels/strided_slice_op.cc | 4 +- tensorflow/core/ops/array_ops.cc | 10 + tensorflow/core/util/strided_slice_op.cc | 101 ++++- tensorflow/core/util/strided_slice_op.h | 5 + .../cpu/codegen/kernel_api_ir_builder.cc | 2 +- .../xla/xla/hlo/builder/lib/approx_topk.cc | 8 +- .../xla/xla/hlo/builder/lib/arithmetic.cc | 5 +- .../xla/xla/hlo/builder/lib/broadcast.cc | 31 +- .../xla/xla/hlo/builder/lib/broadcast.h | 5 +- third_party/xla/xla/hlo/builder/lib/matrix.cc | 15 +- third_party/xla/xla/hlo/builder/lib/prng.cc | 20 +- .../xla/xla/hlo/builder/lib/slicing.cc | 44 ++- third_party/xla/xla/hlo/builder/lib/svd.cc | 23 +- .../xla/xla/hlo/builder/xla_builder.cc | 211 ++++++++--- third_party/xla/xla/hlo/builder/xla_builder.h | 84 ++++- .../xla/xla/hlo/pass/hlo_pass_pipeline.cc | 1 + .../collectives/all_gather_combiner.cc | 4 +- .../expanders/bitcast_dtypes_expander.cc | 11 +- .../transforms/expanders/cholesky_expander.cc | 6 +- .../transforms/expanders/dot_decomposer.cc | 60 ++- .../hlo/transforms/expanders/eigh_expander.cc | 15 +- .../hlo/transforms/expanders/qr_expander.cc | 24 +- .../expanders/rng_bit_generator_expander.cc | 6 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 18 +- .../translate/mhlo_to_hlo/type_to_shape.cc | 4 +- .../xla/xla/service/cpu/dot_op_emitter.cc | 66 +++- third_party/xla/xla/service/cpu/ir_emitter.cc | 117 ++++-- third_party/xla/xla/service/cpu/ir_emitter.h | 4 +- .../xla/xla/service/cpu/ir_emitter2.cc | 10 +- .../xla/service/cpu/parallel_loop_emitter.cc | 6 +- .../xla/xla/service/elemental_ir_emitter.cc | 56 +-- .../xla/xla/service/hlo_creation_utils.cc | 15 +- .../xla/xla/service/llvm_ir/ir_array.cc | 55 +-- .../xla/xla/service/llvm_ir/llvm_loop.cc | 53 ++- .../xla/xla/service/llvm_ir/llvm_loop.h | 8 +- .../xla/xla/service/llvm_ir/llvm_util.cc | 85 ++++- .../xla/xla/service/llvm_ir/llvm_util.h | 13 +- .../xla/xla/service/llvm_ir/loop_emitter.cc | 6 +- .../xla/xla/service/llvm_ir/tuple_ops.cc | 2 +- .../xla/service/reduce_scatter_combiner.cc | 4 +- .../xla/xla/service/shape_inference.cc | 189 ++++++++-- third_party/xla/xla/service/shape_inference.h | 16 +- .../xla/service/triangular_solve_expander.cc | 9 +- third_party/xla/xla/shape.cc | 349 +++++++++++++++++- third_party/xla/xla/shape.h | 289 ++++++++++++++- third_party/xla/xla/shape_util.cc | 75 +++- third_party/xla/xla/shape_util.h | 20 +- .../xla/xla/stream_executor/tpu/c_api_decl.h | 2 + third_party/xla/xla/xla_data.proto | 33 +- 116 files changed, 3964 insertions(+), 696 deletions(-) create mode 100644 tensorflow/core/framework/tensor_shape_expr.cc create mode 100644 tensorflow/core/framework/tensor_shape_expr.h diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 1bcae2384a8058..0662976897378b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -54,11 +54,12 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" - +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { static const absl::flat_hash_set kFailingOps = { - "Pad", "Where", // add more here }; @@ -120,6 +121,133 @@ void MarkGuaranteedConstants( } } +// Helper to convert ExpressionProto to a readable string. +std::string ExprProtoToString(const ExpressionProto& e) { + switch (e.node_type_case()) { + case ExpressionProto::kConstantValue: + return std::to_string(e.constant_value()); + case ExpressionProto::kVariableId: + return absl::StrCat("Var(", e.variable_id(), ")"); + case ExpressionProto::kAddNode: + return absl::StrCat("(", ExprProtoToString(e.add_node().lhs()), " + ", + ExprProtoToString(e.add_node().rhs()), ")"); + case ExpressionProto::kSubNode: + return absl::StrCat("(", ExprProtoToString(e.sub_node().lhs()), " - ", + ExprProtoToString(e.sub_node().rhs()), ")"); + case ExpressionProto::kMulNode: + return absl::StrCat("(", ExprProtoToString(e.mul_node().lhs()), " * ", + ExprProtoToString(e.mul_node().rhs()), ")"); + case ExpressionProto::kDivNode: + return absl::StrCat("(", ExprProtoToString(e.div_node().lhs()), " / ", + ExprProtoToString(e.div_node().rhs()), ")"); + default: + return ""; + } +} + +std::map>> test_map; + +std::unique_ptr ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DynExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DynExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = ExprFromProto(proto.add_node().lhs()); + auto rhs = ExprFromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = ExprFromProto(proto.sub_node().lhs()); + auto rhs = ExprFromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = ExprFromProto(proto.mul_node().lhs()); + auto rhs = ExprFromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = ExprFromProto(proto.div_node().lhs()); + auto rhs = ExprFromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +// Runs Grappler static inference and logs any ExpressionProto found in output +// tensor shapes (from GraphProperties, not from _output_shapes attrs). +void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { + using tensorflow::ExpressionProto; + using tensorflow::GraphDef; + using tensorflow::NodeDef; + using tensorflow::TensorShapeProto; + using tensorflow::grappler::GraphProperties; + using tensorflow::grappler::GrapplerItem; + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + GrapplerItem item; + item.id = "mark_for_compilation_pass_expr_dump"; + item.graph = graph_def; + + GraphProperties props(item); + + absl::Status st = props.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/false, + /*include_input_tensor_values=*/false, + /*include_output_tensor_values=*/false); + + if (!st.ok()) { + LOG(ERROR) << "[EXPR][GP] InferStatically failed: " << st.message(); + return; + } + + int found = 0; + VLOG(1) << "[EXPR][GP] === GraphProperties output expr dump ==="; + + + for (const NodeDef& n : graph_def.node()) { + if (!props.HasOutputProperties(n.name())) continue; + const auto& outs = props.GetOutputProperties(n.name()); + for (int out_idx = 0; out_idx < static_cast(outs.size()); ++out_idx) { + const auto& tp = outs[out_idx]; + const TensorShapeProto& shp = tp.shape(); + + if (shp.unknown_rank()) continue; + std::vector> exprs; + for (int d = 0; d < shp.dim_size(); ++d) { + const auto& dim = shp.dim(d); + + const ExpressionProto& expr = dim.expr(); + if (expr.node_type_case() == ExpressionProto::NODE_TYPE_NOT_SET) + continue; + + VLOG(1) << "Node " << n.name() << " has expression " + << ExprProtoToString(expr); + + auto ex = ExprFromProto(expr); + exprs.push_back(std::move(ex)); + + ++found; + } + test_map[n.name()] = std::move(exprs); + } + } + + VLOG(1) << "[EXPR][GP] === Found " << found + << " expressions via GraphProperties ==="; +} + + struct OutputInputTensorPairHasher { uint64 operator()(std::pair const& s) const { return Hash64Combine(OutputTensor::Hash()(s.first), @@ -480,6 +608,13 @@ absl::Status Encapsulator::Subgraph::RecordArg( auto shape_attr = attrs.FindByString("_output_shapes"); if (shape_attr && shape_attr->has_list()) { const TensorShapeProto& shape = shape_attr->list().shape(src_slot); + std::vector> expressions = + std::move(test_map[src_node->name()]); + for (const auto& e : expressions) { + if (!e->IsConstant()) { + builder.Attr("_is_batch", true); + } + } if (shape.dim_size() >= 1 && shape.dim(0).size() == -1) { VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":" << src_slot; @@ -1183,6 +1318,9 @@ absl::Status EncapsulateSubgraphsPass::Run( options.flib_def); } + LogExpressionsViaGraphProperties(**options.graph); + + // TODO(b/195757077): Remove this once there is a better way to disable // GraphOptimizationPasses that are not needed due to MLIR bridge. for (Node* n : (*options.graph)->nodes()) { diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index c9c8df8a835f47..dd6a2d41053b3b 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/pjrt_compile_util.h" #include "tensorflow/compiler/jit/variable_info.h" @@ -381,6 +382,72 @@ GetXlaCompilerArgsAndSnapshotVariables( return result; } + +std::unique_ptr ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DimExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DimExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = ExprFromProto(proto.add_node().lhs()); + auto rhs = ExprFromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = ExprFromProto(proto.sub_node().lhs()); + auto rhs = ExprFromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = ExprFromProto(proto.mul_node().lhs()); + auto rhs = ExprFromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = ExprFromProto(proto.div_node().lhs()); + auto rhs = ExprFromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { + switch (e->kind()) { + case DimExpr::Kind::kConstant: { + auto* ac = static_cast(e); + return xla::DynExpr::_(ac->value()); + } + case DimExpr::Kind::kVariable: { + auto* av = static_cast(e); + return xla::DynExpr::V(1); + } + case DimExpr::Kind::kAdd: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kSub: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kMul: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kDiv: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs()); + } + } + return nullptr; +} + + absl::Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, @@ -431,54 +498,193 @@ absl::Status CompileToLocalExecutable( MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_enable_dynamic_sizes) { - // Rewriting the argument with the magic number if they have dynamic - // dimension, detecting dynamic dimension via _is_batch attr in the - // argument. + // Rewriting the argument with expressions if they have dynamic + // dimension, detecting dynamic dimension via either _dynamic_dim or the + // inferred-output-shapes attr attached during encapsulation. std::vector norm_args(args.begin(), args.end()); int64_t filled_batch = 0; - XlaBatchMatcher* xla_batch_matcher = xla_device_compiler->xla_batch_matcher(); + bool saw_dynamic_dim_value = false; + // Only supporting one dynamic dimension. + bool has_multiple_dynamic_dim_values = false; + int64_t dynamic_dim_value = 0; + XlaBatchMatcher* xla_batch_matcher = + xla_device_compiler->xla_batch_matcher(); + auto record_dynamic_dim_value = [&](int64_t dim_size) { + if (!saw_dynamic_dim_value) { + saw_dynamic_dim_value = true; + dynamic_dim_value = dim_size; + return; + } + if (dynamic_dim_value != dim_size) { + has_multiple_dynamic_dim_values = true; + } + }; if (options.flib_def != nullptr) { const FunctionDef* fdef = options.flib_def->Find(function.name()); if (fdef != nullptr) { for (const auto& kv : fdef->arg_attr()) { int arg_index = kv.first; const auto& attr_map = kv.second.attr(); - auto it = attr_map.find("_is_batch"); - - const AttrValue& v = it->second; + const std::string& node_name = + fdef->signature().input_arg(arg_index).name(); + + // Special case for _dynamic_dim... + auto dyn_dim_attr = attr_map.find("_dynamic_dim"); + if (dyn_dim_attr != attr_map.end()) { + TensorShape& shp = + std::get(norm_args[arg_index].shape); + const AttrValue& v = dyn_dim_attr->second; + int64_t idx = v.i(); + record_dynamic_dim_value(shp.dim_size(idx)); + if (!filled_batch && xla_batch_matcher) { + filled_batch = + xla_batch_matcher->get_xla_compile_batch(shp.dim_size(idx)); + } + + std::vector dyn_exprs; + for (int d : shp.dim_sizes()) { + dyn_exprs.push_back(xla::DynExpr::_(d)); + } + dyn_exprs[idx] = xla::DynExpr::V(1); + shp.set_expressions(dyn_exprs); + continue; + } + auto it = attr_map.find(kXlaInferredOutputShapesAttrName); if (it == attr_map.end()) continue; - norm_args[arg_index].dynamic_dim = 0; + + const TensorShapeProto& proto = it->second.list().shape(0); + const auto& exp = proto.expressions(); + TensorShape& shp = std::get(norm_args[arg_index].shape); if (!filled_batch && xla_batch_matcher) { - TensorShape& shp = std::get(norm_args[arg_index].shape); - filled_batch = xla_batch_matcher->get_xla_compile_batch(shp.dim_size(0)); + for (int idx = 0; idx < exp.size(); ++idx) { + // Look for dynamic expression. If found then compute padding + // value and exit loop. + auto e = DimExprToDynExpr(ExprFromProto(exp[idx]).get())->s(); + if (e->is_dynamic()) { + int64_t var_value = e->solve(shp.dim_size(idx)); + if (var_value <= 0) { + LOG(WARNING) + << "Failed to solve dynamic dimension for argument " + << arg_index << " dim " << idx << " with size " + << shp.dim_size(idx) + << "; falling back to original dimension size."; + var_value = shp.dim_size(idx); + } else { + VLOG(1) << "Solved dynamic dimension from " + << shp.dim_size(idx) << " to " << var_value; + } + record_dynamic_dim_value(var_value); + filled_batch = + xla_batch_matcher->get_xla_compile_batch(var_value); + break; + } + } + } + + std::vector dyn_exprs; + for (int d : shp.dim_sizes()) { + dyn_exprs.push_back(xla::DynExpr::_(d)); } + for (int j = 0; j < exp.size(); ++j) { + auto e = DimExprToDynExpr(ExprFromProto(exp[j]).get())->s(); + if (e->is_dynamic()) { + dyn_exprs[j] = e; + } + } + shp.set_expressions(dyn_exprs); } } } + struct SaveOldVar { + int arg_index; + int64_t dyn_dim; + int64_t old_value; + }; + std::vector old_vars; + auto maybe_rewrite_scalar_constant = [&](int arg_index) { + if (!saw_dynamic_dim_value || has_multiple_dynamic_dim_values) { + return; + } + + auto& arg = norm_args[arg_index]; + if (arg.kind != XlaCompiler::Argument::kConstant) { + return; + } + + const bool is_scalar = TensorShapeUtils::IsScalar(arg.constant_value.shape()); + const bool is_vector = TensorShapeUtils::IsVector(arg.constant_value.shape()) && + arg.constant_value.NumElements() > 0; + if (!is_scalar && !is_vector) { + return; + } + + if (arg.constant_value.dtype() == DT_INT32) { + const int32 old_value = arg.constant_value.flat()(0); + // Heuristic: rewrite only scalar constants or shape-like int vectors + // whose leading entry matches the observed runtime batch size. + if (old_value == dynamic_dim_value) { + // Deep-copy before rewrite so the compile-time patch does not mutate + // a Tensor buffer shared with caller-visible inputs. + Tensor scalar_copy(arg.constant_value.dtype(), + arg.constant_value.shape()); + scalar_copy.flat() = arg.constant_value.flat(); + arg.constant_value = std::move(scalar_copy); + arg.constant_value.flat()(0) = + static_cast(filled_batch); + } + } else if (arg.constant_value.dtype() == DT_INT64) { + const int64_t old_value = arg.constant_value.flat()(0); + // Same heuristic for int64 scalar constants. + if (old_value == dynamic_dim_value) { + Tensor scalar_copy(arg.constant_value.dtype(), + arg.constant_value.shape()); + scalar_copy.flat() = arg.constant_value.flat(); + arg.constant_value = std::move(scalar_copy); + arg.constant_value.flat()(0) = filled_batch; + } + } + }; + // We rewrite only dynamic dimensions to the padded compile batch and then + // restore the original runtime sizes after compilation. Some scalar + // constants are actually runtime batch sizes folded by earlier TF passes, + // so rewrite only those that match the detected dynamic runtime value. + // Scalar constants are deep-copied before rewrite so the change stays + // local to norm_args and does not require restoration. if (filled_batch) { for (int i = 0; i < norm_args.size(); ++i) { - auto& arg = norm_args[i]; - // argument rewrite. - if (arg.dynamic_dim == 0) { - TensorShape& shp = std::get(arg.shape); - int64_t old = shp.dim_size(0); - shp.set_dim(0, filled_batch); - } - // constant argument rewrite otherwise it still store the incoming batch - // request. - if (arg.kind == XlaCompiler::Argument::kConstant) { - auto flat = arg.constant_value.flat(); - int32 old_batch = flat(0); - flat(0) = static_cast(filled_batch); + TensorShape& shp = std::get(norm_args[i].shape); + for (int j = 0; j < shp.get_expressions().size(); ++j) { + auto e = shp.get_expression(j); + if (e->is_dynamic()) { + int64_t old = shp.dim_size(j); + old_vars.push_back({i, j, old}); + xla::DynExpr* padded_expr = xla::DynExpr::_(filled_batch); + xla::DynExpr* subst_expr = e->substitute(1, padded_expr)->s(); + int64_t new_dim = subst_expr->get_val(); + if (new_dim >= 0) { + shp.set_dim(j, new_dim); + // Necessary because set_dim removes the expression: + shp.set_expression(j, e); + } + } } + maybe_rewrite_scalar_constant(i); } } - - return xla_device_compiler->CompileIfNeeded( + auto status = xla_device_compiler->CompileIfNeeded( options, function, norm_args, compile_options, compile_mode, profiler, compilation_result, executable); + // Restore the original runtime dimensions after compilation. + if (filled_batch) { + for (const auto& old_var : old_vars) { + TensorShape& shp = + std::get(norm_args[old_var.arg_index].shape); + shp.set_dim(old_var.dyn_dim, old_var.old_value); + } + } + return status; } else { return xla_device_compiler->CompileIfNeeded( options, function, args, compile_options, compile_mode, profiler, @@ -937,7 +1143,6 @@ XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); - bool use_pjrt = GetXlaOpsCommonFlags() ->tf_xla_use_device_api.IsEnabledInXlaCompileAndRunForDevice( @@ -1021,22 +1226,64 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_enable_dynamic_sizes) { - BatchSizeResource* bsr = nullptr; - ScopedStepContainer* step_container = ctx->step_container(); - - absl::Status st = step_container->Lookup( - ctx->resource_manager(), BatchSizeResourceName, &bsr); - - if (st.ok()) { - run_options.set_batch_size(bsr->GetBatchSize()); - VLOG(1) << "run_options.batch_size is set to: " - << run_options.batch_size() << ". step_id: " << ctx->step_id(); - bsr->Unref(); - - } else if (IsNotFound(st)) { - VLOG(1) << "Warning: Not found BatchSizeResource in step_container."; + bool is_set = false; + std::set dyn_vals; + const auto* comp_result = closure.compilation_result(); + const int num_constant_args = closure.num_constant_args(); + for (int i = 0; i < comp_result->xla_input_shapes.size(); i++) { + const auto& xla_shape = closure.compilation_result()->xla_input_shapes[i]; + if (!xla_shape.IsArray() || xla_shape.expressions().empty()) continue; + + for (int dim = 0; dim < xla_shape.expressions().size(); dim++) { + xla::DynExpr* expr = xla_shape.expressions(dim); + if (expr && expr->is_dynamic()) { + int input_idx = comp_result->input_mapping[i] - num_constant_args; + if (input_idx < 0 || input_idx >= ctx->num_inputs()) { + VLOG(1) << "Warning: Input index is out of range"; + continue; + } + VLOG(1) << "input shape is " << ctx->input(input_idx).shape() + << ", corresponding xla input shape is " << xla_shape; + int64_t size = ctx->input(input_idx).shape().dim_size(dim); + int64_t dyn_val = expr->solve(size); // TODO: check if the result is correct later. + VLOG(1) << "Found dynamic input. Real size is: " << size + << ", solved dynamic value is " << dyn_val; + if (dyn_val == -1) { + VLOG(1) << "Warning: Failed to solve the expression"; + continue; + } + dyn_vals.insert(dyn_val); + } + } + } + + if (dyn_vals.size() == 1) { + run_options.set_batch_size(*(dyn_vals.begin())); + is_set = true; } else { - OP_REQUIRES_OK(ctx, st); + // Found multiple variables + VLOG(1) << "Warning: Found multiple variables"; + } + + if (!is_set) { + // TODO: Fallback to BatchSizeResource for now. Remove it later. + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + + absl::Status st = step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr); + + if (st.ok()) { + run_options.set_batch_size(bsr->GetBatchSize()); + VLOG(1) << "run_options.batch_size is set to: " + << run_options.batch_size() << ". step_id: " << ctx->step_id(); + bsr->Unref(); + + } else if (IsNotFound(st)) { + VLOG(1) << "Warning: Not found BatchSizeResource in step_container."; + } else { + OP_REQUIRES_OK(ctx, st); + } } } @@ -1068,7 +1315,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { launch_context.PopulateOutputs( ctx, closure.compilation_result(), execution_output->ConsumeResult(), /*missing_ctx_input_prefix=*/closure.num_constant_args(), - absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs)); + absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs, + &run_options)); } XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 7ed57392c0ce9f..0b33abc4612931 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -70,6 +70,9 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { @@ -1919,7 +1922,6 @@ absl::Status MarkForCompilationPassImpl::Run() { // MarkForCompilationPassImpl is not set up to run the subsequent phases. return absl::OkStatus(); } - TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); TF_RETURN_IF_ERROR(DeclusterNodes()); TF_RETURN_IF_ERROR(CreateClusters()); diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 7a5106aa69bbbf..5f270d701a9c44 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -49,10 +49,17 @@ absl::Status ShapeHandleToTensorShape( if (!context->RankKnown(handle)) return absl::OkStatus(); std::vector dims(context->Rank(handle)); + std::vector dyn_exprs(context->Rank(handle)); for (int32_t i = 0, end = dims.size(); i < end; ++i) { dims[i] = context->Value(context->Dim(handle, i)); + auto ratio = context->DynamicRatio(context->Dim(handle, i)); + dyn_exprs[i] = ratio > 0 ? (ratio * *xla::DynExpr::V(1))->s() + : xla::DynExpr::_(dims[i]); // For now } - return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); + auto status = + PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); + shape->set_expressions(dyn_exprs); + return status; } absl::Status PropagateShapes( diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3f0e5df76b8c40..a045280b48b872 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -432,19 +432,30 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( } else { for (int i = 0; i < ctx->num_outputs(); ++i) { xla::Shape output_host_shape = output.on_host_shape(); - const xla::Shape& subshape = xla::ShapeUtil::GetSubshape(output_host_shape, {i}); + const xla::Shape& subshape = + xla::ShapeUtil::GetSubshape(output_host_shape, {i}); VLOG(2) << "PopulateOutputs: subshape[" << i << "]: "<< subshape; TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); - if (subshape.outer_multiplier() > 0) { - BatchSizeResource* bsr = nullptr; - ScopedStepContainer* step_container = ctx->step_container(); - TF_RETURN_IF_ERROR(step_container->Lookup( - ctx->resource_manager(), BatchSizeResourceName, &bsr)); - auto bsm = bsr->GetBatchSize() * subshape.outer_multiplier() ; - shape.set_dim(0, bsm); + bool has_dynamic = false; + + for(int i = 0 ; i < subshape.expressions().size(); ++i){ + auto expr = subshape.expressions(i); + if (expr->is_dynamic()){ + has_dynamic = true; + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + TF_RETURN_IF_ERROR(step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr)); + xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize()); + // Just substitute Var(1) for now. + xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + shape.set_dim(i, subst_expr->get_val()); + bsr->Unref(); + } + } + if (has_dynamic) { output_tensor_shapes.push_back(shape); - bsr->Unref(); } else { output_tensor_shapes.push_back(compilation_result->outputs[i].shape); diff --git a/tensorflow/compiler/tf2xla/kernels/beta_op.cc b/tensorflow/compiler/tf2xla/kernels/beta_op.cc index 4ead9f76fcee11..9159df4f6a4ad5 100644 --- a/tensorflow/compiler/tf2xla/kernels/beta_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/beta_op.cc @@ -63,11 +63,14 @@ class BetaincOp : public XlaOpKernel { auto result = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN( - auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes())); + auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes(), + merged_shape.get_expressions())); TF_ASSIGN_OR_RETURN( - auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes())); + auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes(), + merged_shape.get_expressions())); TF_ASSIGN_OR_RETURN( - auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes())); + auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes(), + merged_shape.get_expressions())); return xla::RegularizedIncompleteBeta(a, b, x); }); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index fadd2f87219464..4f31c79f91a719 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -117,9 +117,13 @@ class DenseBincountOp : public XlaOpKernel { xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()}); auto i = xla::Iota(ctx->builder(), i_shape, 0); i = xla::Reshape( - i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}); + i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, + {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(), + xla::DynExpr::one}); auto j = xla::Reshape( - input, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}); + input, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, + {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(), + xla::DynExpr::one}); std::vector iotas_to_concat; iotas_to_concat.push_back(i); iotas_to_concat.push_back(j); @@ -130,7 +134,8 @@ class DenseBincountOp : public XlaOpKernel { zero, {output_shape.dimensions(0), output_shape.dimensions(1)}); if (has_weights && !binary_output_) { weights = xla::Reshape( - weights, {input_shape.dimensions(0) * input_shape.dimensions(1)}); + weights, {input_shape.dimensions(0) * input_shape.dimensions(1)}, + {(*input_shape.expressions(0) * *input_shape.expressions(1))->s()}); updates = weights; } } else { diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 975179466bf104..8176c4c820cb40 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -37,7 +37,8 @@ class BroadcastToOp : public XlaOpKernel { context->ConstantInputAsShape( 1, &output_shape, xla::ValueInferenceMode::kUpperBound)); auto output_status_or = - BroadcastTo(context->Input(0), output_shape.dim_sizes()); + BroadcastTo(context->Input(0), output_shape.dim_sizes(), + output_shape.get_expressions()); OP_REQUIRES_OK(context, output_status_or.status()); auto output = output_status_or.value(); std::vector dynamic_dims; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 6b4f278c72beff..e3459681a98b13 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -46,11 +46,11 @@ class ClipByValueOp : public XlaOpKernel { if (shape != min_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); - min = xla::Broadcast(min, shape.dim_sizes()); + min = xla::Broadcast(min, shape.dim_sizes(), shape.get_expressions()); } if (shape != max_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); - max = xla::Broadcast(max, shape.dim_sizes()); + max = xla::Broadcast(max, shape.dim_sizes(), shape.get_expressions()); } ctx->SetOutput(0, xla::Clamp(min, input, max)); } diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index d2463a9974b1bb..043a263b9f2605 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -123,7 +123,8 @@ class ConstOp : public XlaOpKernel { if (shape.num_elements() > 1) { xla::XlaOp value = GetScalarConst(proto_, b); if (value.valid()) { - ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes(), + shape.get_expressions())); return; } } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 3fe22dcb4441e7..466707f0d777e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -97,8 +97,11 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput( CHECK_GE(num_dims, 2); // Crash OK xla::Shape new_shape = filter_shape; new_shape.set_dimensions(num_dims - 1, num_groups); - new_shape.add_dimensions(filter_shape.dimensions(num_dims - 1) / num_groups); - xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions()); + new_shape.add_dimensions( + filter_shape.dimensions(num_dims - 1) / num_groups, + (*filter_shape.expressions(num_dims - 1) / num_groups)->s()); + xla::XlaOp result = + xla::Reshape(filter, new_shape.dimensions(), new_shape.expressions()); // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G] std::vector transpose_dims(num_dims + 1); @@ -118,7 +121,8 @@ xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, xla::XlaOp filter) { return xla::Reshape( filter, - GroupedFilterShapeForDepthwiseConvolution(filter_shape).dimensions()); + GroupedFilterShapeForDepthwiseConvolution(filter_shape).dimensions(), + GroupedFilterShapeForDepthwiseConvolution(filter_shape).expressions()); } // Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA @@ -603,7 +607,8 @@ absl::StatusOr MakeXlaBackpropFilterConvOp( } if (attrs.depthwise) { - filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions()); + filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions(), + filter_shape.expressions()); } return filter_backprop; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index b1da0acd61608f..f90bad207cfbd4 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -110,7 +110,8 @@ class ConvNDOp : public XlaOpKernel { expanded_input_shape.set_dimensions(i + 1, input_shape.dimensions(i)); } expanded_input_shape.set_dimensions(0, 1); - input = xla::Reshape(input, expanded_input_shape.dimensions()); + input = xla::Reshape(input, expanded_input_shape.dimensions(), + expanded_input_shape.expressions()); } else if (attrs_.batch_dims > 1) { // Flatten batch_dims. std::vector to_collapse(attrs_.batch_dims); @@ -131,7 +132,8 @@ class ConvNDOp : public XlaOpKernel { if (attrs_.batch_dims == 0) { xla::Shape no_batch_shape(out_shape); no_batch_shape.DeleteDimension(0); - out = xla::Reshape(out, no_batch_shape.dimensions()); + out = xla::Reshape(out, no_batch_shape.dimensions(), + no_batch_shape.expressions()); } else if (attrs_.batch_dims > 1) { xla::Shape expanded_out_shape(input_shape); for (int i = attrs_.batch_dims; i < input_shape.dimensions().size(); @@ -139,7 +141,8 @@ class ConvNDOp : public XlaOpKernel { expanded_out_shape.set_dimensions( i, out_shape.dimensions(i - (attrs_.batch_dims - 1))); } - out = xla::Reshape(out, expanded_out_shape.dimensions()); + out = xla::Reshape(out, expanded_out_shape.dimensions(), + expanded_out_shape.expressions()); } ctx->SetOutput(0, out); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 404fa9f5e04e45..d8740aa137fbeb 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -103,7 +103,7 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64_t size = input_shape.num_elements(); - input = xla::Reshape(input, {size}); + input = xla::Reshape(input, {size}, {}); // Create an R2 with the R1 diagonal. xla::XlaOp diag = CreateDiagonal(input, size, /*other_dims=*/{}); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index ceeea010ee7858..84c091339c1804 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -140,8 +140,9 @@ class DynamicPartitionOp : public XlaOpKernel { for (int64_t i = 0; i < rank; ++i) { broadcasted_dims.push_back(i); } - partitions = xla::BroadcastInDim(partitions, data_shape.dimensions(), - broadcasted_dims); + partitions = + xla::BroadcastInDim(partitions, data_shape.dimensions(), + broadcasted_dims, data_shape.expressions()); } // Output shape bounded is calculated by diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 2a65441eb79bf9..d19e9f76f2c9b0 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -127,7 +127,8 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { xla::XlaOp between_nudged_min_max = xla::And( xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type), - gradient_shape.dim_sizes()); + gradient_shape.dim_sizes(), + gradient_shape.get_expressions()); xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -213,7 +214,8 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { xla::XlaOp between_nudged_min_max = xla::And( xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes(), + gradient_shape.get_expressions()); xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); @@ -268,9 +270,12 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public XlaOpKernel { xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); + absl::Span input_expressions = + input_shape.expressions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, - {input_shape.dimensions_size() - 1}); + {input_shape.dimensions_size() - 1}, + input_expressions); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); @@ -323,6 +328,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); + absl::Span input_expressions = + input_shape.expressions(); std::vector reduce_axes; for (int64_t i = 0; i + 1 < input_shape.dimensions_size(); ++i) { @@ -331,7 +338,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, - {input_shape.dimensions_size() - 1}); + {input_shape.dimensions_size() - 1}, + input_expressions); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); @@ -343,7 +351,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { xla::XlaOp between_nudged_min_max = xla::And( xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes(), + gradient_shape.get_expressions()); xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2783951e1b6b0f..e87155821c523f 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -88,7 +88,8 @@ absl::Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, out_shape.AppendShape(input_shape_post_axis); *gather_output = - xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); + xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes(), + out_shape.get_expressions()); return absl::OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index a8eb7bbf794268..162ce34488a620 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -59,7 +59,7 @@ std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, auto minimum = xla::Min(xla::Min(red, green), blue); auto range = xla::Sub(value, minimum); - auto zeros = xla::Broadcast(zero, shape.dim_sizes()); + auto zeros = xla::Broadcast(zero, shape.dim_sizes(), shape.get_expressions()); auto saturation = xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index f357262a39c35b..e6e90965dbffcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -87,9 +87,11 @@ class InTopKOp : public XlaOpKernel { // which indicates the target is in topk. xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0}); xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32); - xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes()); + xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes(), + predictions_shape.get_expressions()); xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32); - xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes()); + xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes(), + predictions_shape.get_expressions()); xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2); xla::XlaOp num_gt_r1 = xla::Reduce( one_hot_r2, zero_r0, diff --git a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc index 46e46f6d8b3d32..3ff4c4b0c6a221 100644 --- a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc @@ -49,14 +49,16 @@ void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype, // dimension of sorted_sequence. auto new_values_shape = values_shape; new_values_shape.InsertDim(/* d */ 2, /* size */ 1); - auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes()); + auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes(), + new_values_shape.get_expressions()); // Add a new penultimate dimension to sorted_inputs, to allow broadcasting of // sorted_sequence entries for each value. auto new_sorted_inputs_shape = sorted_inputs_shape; new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1); auto sorted_inputs_reshaped = - xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes()); + xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes(), + new_sorted_inputs_shape.get_expressions()); // We are relying on broadcasting to compare each value against each entry in // the associated sorted_inputs row. diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 48e8f976cc67bb..ac798db3460ede 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -234,7 +234,8 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, // Broadcast and mask. xla::XlaOp diag_broadcast = xla::BroadcastInDim( - diag_slice, input_shape.dim_sizes(), broadcast_dimensions); + diag_slice, input_shape.dim_sizes(), broadcast_dimensions, + input_shape.get_expressions()); const auto mask = xla::GetDiagonalMask(output, diag_index); output = xla::Select(mask, diag_broadcast, output); } @@ -328,7 +329,8 @@ class MatrixDiagOp : public XlaOpKernel { output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2); output_shape.AddDim(num_rows); output_shape.AddDim(num_cols); - xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes()); + xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes(), + output_shape.get_expressions()); xla::XlaOp diag = context->Input(0); context->SetOutput( 0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags, @@ -447,7 +449,8 @@ class MatrixDiagPartOp : public XlaOpKernel { } auto concat = xla::ConcatInDim(context->builder(), diag_list, input_rank - 2); - context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes())); + context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes(), + output_shape.get_expressions())); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 17b5ae7a70375a..a6d618d1b7889e 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -97,7 +97,8 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape()); lhs_broadcast_shape.AddDim(m); lhs_broadcast_shape.AddDim(m); - auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes()); + auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes(), + lhs_broadcast_shape.get_expressions()); if (!lhs_output.ok()) { xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); return {error, error}; @@ -106,7 +107,8 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape()); rhs_broadcast_shape.AddDim(m); rhs_broadcast_shape.AddDim(n); - auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes()); + auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes(), + rhs_broadcast_shape.get_expressions()); if (!rhs_output.ok()) { xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); return {error, error}; diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index ba4e8bbef7b136..3d84437dba7869 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -66,9 +66,17 @@ class PackOp : public XlaOpKernel { TensorShape child_shape(shapes[0]); child_shape.InsertDim(axis, 1); + // Equivalent to InsertDim(axis, 1) for expressions + std::vector exprs; + for (auto e : child_shape.get_expressions()) { + exprs.push_back(e); + } + exprs.insert(exprs.begin() + axis, xla::DynExpr::one); + for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. - reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes()); + reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes(), + exprs); } ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index aa7c78b8b8f97a..b3cdb7bac0c7dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -264,10 +264,14 @@ class MaxPoolOp : public PoolingOp { absl::InlinedVector new_dims(result_shape->dimensions().begin(), result_shape->dimensions().end()); + absl::InlinedVector new_exprs( + result_shape->expressions().begin(), + result_shape->expressions().end()); new_dims[1] /= *vect_width; + new_exprs[1] = *new_exprs[1] / *vect_width; new_dims.insert(new_dims.begin() + 2, *vect_width); - pooling = - xla::Transpose(xla::Reshape(pooling, new_dims), {0, 1, 3, 4, 2}); + pooling = xla::Transpose(xla::Reshape(pooling, new_dims, new_exprs), + {0, 1, 3, 4, 2}); } ctx->SetOutput(0, pooling); diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index cac9f8a68f234e..ba6860caa9cb1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -173,8 +173,11 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { if (!xla::ShapeUtil::IsScalar(axis_shape)) { xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); + absl::Span input_expressions = + input_shape.expressions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { - return xla::BroadcastInDim(op, input_dimensions, {axis_}); + return xla::BroadcastInDim(op, input_dimensions, {axis_}, + input_expressions); }; min_range = convert_to_input_shape(min_range); max_range = convert_to_input_shape(max_range); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 6a8a98342c1123..9ed1116c08dc7a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -106,16 +106,19 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { } std::vector final_shape; + std::vector final_exprs; for (int i = 0; i < data_shape.dims(); ++i) { if (!bitmap[i]) { // If we are not reducing along dimension i. int64_t dim = data_shape.dim_size(i); final_shape.push_back(dim); + final_exprs.push_back(data_shape.get_expression(i)); } else if (keep_dims_) { // We are reducing along dimension i, but we want to keep the // same number of dimensions, so we set the dimension of i to // '1'. final_shape.push_back(1); + final_exprs.push_back(xla::DynExpr::one); } } @@ -139,7 +142,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto finalized = BuildFinalizer(b, data, reduce, xla_axes); - auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; + auto result = keep_dims_ ? xla::Reshape(finalized, final_shape, final_exprs) + : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index f274b271596ff5..9a7c875347f120 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -40,8 +40,28 @@ XlaOp Relu6(XlaOp x) { namespace tensorflow { namespace { -REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel); -REGISTER_XLA_OP(Name("Relu6"), MlirXlaOpKernel); +class ReluOp : public XlaOpKernel { + public: + explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, xla::Relu(ctx->Input(0))); + } +}; + +class Relu6Op : public XlaOpKernel { + public: + explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, xla::Relu6(ctx->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("Relu"), ReluOp); +// REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel); +// REGISTER_XLA_OP(Name("Relu6"), MlirXlaOpKernel); +REGISTER_XLA_OP(Name("Relu6"), Relu6Op); class LeakyReluOp : public XlaOpKernel { public: @@ -70,9 +90,11 @@ class Relu6GradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto six = xla::Broadcast( - XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes(), + shape.get_expressions()); + const auto six = + xla::Broadcast(XlaHelpers::IntegerLiteral(b, input_type(0), 6), + shape.dim_sizes(), shape.get_expressions()); auto out = xla::Select( xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)), ctx->Input(0), zero); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index ba17d1b295b763..cadbafe2d27fbd 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -57,8 +57,10 @@ class ReshapeOp : public XlaOpKernel { // is one. TensorShape shape; int64_t product = 1; + xla::DynExpr* product_expr = xla::DynExpr::one; int unknown_index = -1; bool shape_has_zero_dim = false; + int ratio = 1; for (int d = 0; d < num_dims; ++d) { const int64_t size = shape_input[d]; if (size == -1) { @@ -68,23 +70,58 @@ class ReshapeOp : public XlaOpKernel { unknown_index, " and ", d)); unknown_index = d; shape.AddDim(1); + shape.AddExpression(xla::DynExpr::one); } else if (size == 0) { // We don't include zero-sized dimension in product, so that we can // still calculate number of elements for non-zero-sized dimensions and // therefore infer their shapes. shape.AddDim(size); + shape.AddExpression(xla::DynExpr::_(size)); shape_has_zero_dim = true; } else { + xla::DynExpr* size_expr; OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( "size ", d, " must be non-negative, not ", size)); shape.AddDim(size); + if (d < input_shape.dims() && + input_shape.get_expression(d)->is_dynamic()) { + int old = input_shape.dim_size(d); + bool is_split = (old > size); + int local_ratio = ratio * (is_split ? old / size : size / old); + xla::DynExpr* new_expr = + (size > old) + ? *input_shape.get_expression(d) * + *xla::DynExpr::_(local_ratio) // Split [xy] -> [x/y,y] + : *input_shape.get_expression(d) / + *xla::DynExpr::_(local_ratio); // Reduce [x,y] -> [x*y] + + // Pass ratio to next dimension if this is a split, otherwise just + // reset it to 1. + ratio = is_split ? local_ratio : 1; + size_expr = new_expr->s(); + + } else { + if (ratio == 1){ // Nothing has been previously split. + size_expr = xla::DynExpr::_(size); + } else if (ratio == size) { // The factor of the previous split is + // the new dimension. + size_expr = xla::DynExpr::_(size); + ratio = 1; // reset ratio + } else { + // Should not happen. + size_expr = xla::DynExpr::_(-50); + } + } + shape.AddExpression(size_expr); product *= size; + product_expr = (*product_expr * *size_expr); } } auto input = ctx->Input(0); if (unknown_index != -1) { int64_t input_num_elements = 1; + xla::DynExpr* input_num_elements_expr = xla::DynExpr::one; bool input_has_zero_dim = false; for (int dim = 0; dim < input_shape.dims(); dim++) { // For zero dimension, we don't count it into `input_num_elements` @@ -92,12 +129,17 @@ class ReshapeOp : public XlaOpKernel { // infer shapes for other dimensions. if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) { input_num_elements *= input_shape.dim_size(dim); + input_num_elements_expr = + (*input_num_elements_expr * *input_shape.get_expression(dim))->s(); } else { input_has_zero_dim = true; } } int64_t missing = input_num_elements / product; + input_num_elements_expr = input_num_elements_expr->s(); + product_expr = product_expr->s(); + auto missing_expr = *input_num_elements_expr / *product_expr; if (!input_has_zero_dim) { if (input_xla_shape->is_static() || input_xla_shape->dimensions().size() != 1) { @@ -119,9 +161,14 @@ class ReshapeOp : public XlaOpKernel { input, xla::Zero(ctx->builder(), input_xla_shape->element_type()), 0, 0, padded_input_num - input_num_elements); input_shape.set_dim(0, padded_input_num); + input_shape.set_expression( + 0, xla::DynExpr::_( + padded_input_num)); // Issue here as it depends on ceil } } shape.set_dim(unknown_index, missing); + shape.set_expression( + unknown_index, missing_expr->s()); } OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), errors::InvalidArgument("Input to reshape is a tensor with ", @@ -131,19 +178,23 @@ class ReshapeOp : public XlaOpKernel { VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; + if (input_xla_shape->is_static()) { - ctx->SetOutput(0, xla::Reshape(input, shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Reshape(input, shape.dim_sizes(), shape.get_expressions())); return; } std::vector output_dim_sizes; std::vector dims_are_dynamic; + std::vector expressions; const auto& dims = shape.dims(); dims_are_dynamic.reserve(dims); output_dim_sizes.reserve(dims); for (int64_t i = 0; i < dims; ++i) { output_dim_sizes.push_back( xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); + expressions.push_back(xla::DynExpr::_(-10)); } OP_REQUIRES_OK( ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); @@ -151,7 +202,7 @@ class ReshapeOp : public XlaOpKernel { // No unknown index. ctx->SetOutput( 0, xla::DynamicReshape(input, output_dim_sizes, shape.dim_sizes(), - dims_are_dynamic)); + dims_are_dynamic, expressions)); return; } auto common_factors = @@ -162,6 +213,7 @@ class ReshapeOp : public XlaOpKernel { auto start = common_factors[i]; auto end = common_factors[i + 1]; bool input_is_dynamic = false; + xla::DynExpr* expression = xla::DynExpr::_(-20); // product of all input dims in this group. E.g., in // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group // containing -1 will be 6. @@ -188,12 +240,13 @@ class ReshapeOp : public XlaOpKernel { // If input dim is dynamic, output dim at the -1 position must be // dynamic. Similarly, if input dim is static, output dim has to be // static at the -1 dimension. + expressions[unknown_index] = expression; dims_are_dynamic[unknown_index] = input_is_dynamic; output_dim_sizes[unknown_index] = unknown_dim_size; ctx->SetOutput( 0, xla::DynamicReshape(input, output_dim_sizes, shape.dim_sizes(), - dims_are_dynamic)); + dims_are_dynamic, expressions)); VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to " << xla::VectorString(shape.dim_sizes()) << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic); diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 694b4eb17ef298..95f52049872710 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -118,7 +118,8 @@ class ScatterNdOp : public XlaOpKernel { xla::XlaBuilder* builder = context->builder(); auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype), - buffer_shape.dim_sizes()); + buffer_shape.dim_sizes(), + buffer_shape.get_expressions()); auto indices = context->Input(0); auto updates = context->Input(1); auto combine = diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 21eaac25f058ed..faf6822b9a074d 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -96,7 +96,8 @@ class SegmentReduce : public XlaOpKernel { buffer_shape.InsertDim(0, num_segments); auto buffer = - xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); + xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes(), + buffer_shape.get_expressions()); // Build dynamic dim sizes for buffer, as well as whether each dimension // size is dynamic or static. We build two parts: num_sgement part and diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 85aaabe87076c2..3241dfc61609a1 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -70,8 +70,10 @@ class SelectOp : public XlaOpKernel { // Broadcast into the dimensions on the right. std::vector broadcast_dimensions(cond_shape.dims()); absl::c_iota(broadcast_dimensions, 0); + cond_handle = xla::BroadcastInDim(cond_handle, then_shape.dim_sizes(), - broadcast_dimensions); + broadcast_dimensions, + then_shape.get_expressions()); } ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } @@ -81,7 +83,8 @@ class SelectOp : public XlaOpKernel { void operator=(const SelectOp&) = delete; }; -REGISTER_XLA_OP(Name("Select"), MlirXlaOpKernel); +// REGISTER_XLA_OP(Name("Select"), MlirXlaOpKernel); +REGISTER_XLA_OP(Name("Select"), SelectOp); class SelectOpV2 : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 7e8889cb2ccee6..d376dfb9240876 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -290,6 +290,8 @@ class ExpandDimsOp : public XlaOpKernel { " dimensions.")); auto existing_dims = input_shape.dim_sizes(); + auto existing_exprs = input_shape.get_expressions(); + // Safe - # elements in tensor dims bounded. const int existing_dims_size = static_cast(existing_dims.size()); std::vector new_shape(existing_dims_size); @@ -297,6 +299,12 @@ class ExpandDimsOp : public XlaOpKernel { new_shape[i] = existing_dims[i]; } + const int existing_exprs_size = static_cast(existing_exprs.size()); + std::vector new_exprs(existing_exprs_size); + for (size_t i = 0; i < new_exprs.size(); ++i) { + new_exprs[i] = existing_exprs[i]; + } + // We emulate numpy's interpretation of the dim axis when // -input.dims() >= dim <= input.dims(). if (dim < 0) { @@ -306,8 +314,9 @@ class ExpandDimsOp : public XlaOpKernel { // Clamp to the end if needed. dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); + new_exprs.emplace(new_exprs.begin() + dim, xla::DynExpr::one); - ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape, new_exprs)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), @@ -430,7 +439,8 @@ class ZerosLikeOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); xla::XlaOp input = ctx->Input(0); auto input_shape = ctx->InputXlaShape(0).value(); - auto result = xla::Broadcast(zero, input_shape.dimensions()); + auto result = xla::Broadcast(zero, input_shape.dimensions(), + input_shape.expressions()); // Setting up dynamic dimensions of the broadcast. for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { @@ -455,7 +465,8 @@ class OnesLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto one = XlaHelpers::One(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes(), + input_shape.get_expressions())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 844a31f97990fc..ef93f16c78cf42 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -97,13 +97,21 @@ class SliceOp : public XlaOpKernel { } } + std::vector begin_exprs; + for (int d : begin){ + begin_exprs.push_back(xla::DynExpr::_(d)); + } std::vector limits; + std::vector exprs; limits.reserve(begin.size()); + exprs.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + wrapped_size[i]); + exprs.push_back(xla::DynExpr::_(begin[i] + wrapped_size[i])); } std::vector strides(begin.size(), 1); - auto slice = xla::Slice(ctx->Input(0), begin, limits, strides); + auto slice = + xla::Slice(ctx->Input(0), begin, limits, begin_exprs, exprs, strides); // Check for slice on dynamic dimensions. std::vector size_is_dynamic; OP_REQUIRES_OK( @@ -114,8 +122,10 @@ class SliceOp : public XlaOpKernel { if (size[i] != -1) { // If there is a dynamic dimension, properly set dimension size of // the slice. - auto dynamic_size = - xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {}); + auto dynamic_size = xla::Reshape( + xla::Slice(ctx->Input(2), {i}, {i + 1}, {xla::DynExpr::_(i)}, + {xla::DynExpr::_(i + 1)}, {1}), + {}); slice = xla::SetDimensionSize(slice, dynamic_size, i); } diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 330479bc8d4150..b1d7e16bbda741 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -35,11 +35,72 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" +#include "tensorflow/compiler/tf2xla/type_util.h" + namespace tensorflow { namespace { -REGISTER_XLA_OP(Name("Softmax"), MlirXlaOpKernel); -REGISTER_XLA_OP(Name("LogSoftmax"), MlirXlaOpKernel); +class SoftmaxOp : public XlaOpKernel { + public: + explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + log_ = absl::StartsWith(type_string(), "Log"); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape logits_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_shape.DebugString())); + + // Major dimensions are batch dimensions, minor dimension is the class + // dimension. + std::vector batch_dims(logits_shape.dims() - 1); + std::iota(batch_dims.begin(), batch_dims.end(), 0); + const int kClassDim = logits_shape.dims() - 1; + + const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); + auto logits = ctx->Input(0); + + xla::XlaBuilder* const b = ctx->builder(); + + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); + + // Find the max in each batch, resulting in a tensor of shape [batch] + auto logits_max = + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); + // Subtract the max in batch b from every element in batch b. Broadcasts + // along the batch dimension. + auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); + auto exp_shifted = xla::Exp(shifted_logits); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + xla::PrimitiveType xla_accumulation_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, + &xla_accumulation_type)); + auto converted = + xla::ConvertElementType(exp_shifted, xla_accumulation_type); + auto reduce = + xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum = XlaHelpers::ConvertElementType(reduce, type); + auto softmax = + log_ + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + : xla::Div(exp_shifted, sum, batch_dims); + ctx->SetOutput(0, softmax); + } + + private: + bool log_; +}; + +REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); +REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); std::pair CrossEntropyWithLogits( XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type, diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index b4d589f183108e..1368326cf851df 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -83,7 +83,8 @@ class SparseToDenseOp : public XlaOpKernel { sparse_values = Broadcast(sparse_values, {num_elems}); } xla::XlaBuilder* builder = context->builder(); - auto buffer = Broadcast(default_value, output_shape.dim_sizes()); + auto buffer = Broadcast(default_value, output_shape.dim_sizes(), + output_shape.get_expressions()); std::vector dynamic_dims; OP_REQUIRES_OK( context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 4f7c4ae99b6b6b..482438cedfb42b 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -81,17 +81,21 @@ class SplitOp : public XlaOpKernel { // All the slices are the same size: this is the size along the // split dimension. const int32_t slice_size = input_shape.dim_size(split_dim) / num_split; + auto slice_expr = *input_shape.get_expression(split_dim) / num_split; // The vectors we will use to define the slice. The entry for the // split dimensions varies for each output. std::vector begin(input_shape.dims(), 0); std::vector limits(input_shape.dims()); + std::vector begin_expr(input_shape.dims(), xla::DynExpr::zero); + std::vector limits_expr(input_shape.dims()); std::vector strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { // Initially set up the limits to be the full size of the input: // the split dimension is filled in below. int64_t dim = input_shape.dim_size(i); limits[i] = dim; + limits_expr[i] = input_shape.get_expression(i); } // Create each of the outputs. @@ -99,7 +103,12 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); + + begin_expr[split_dim] = i * *slice_expr; + limits_expr[split_dim] = (*xla::DynExpr::_(i + 1) * *slice_expr)->s(); + + ctx->SetOutput(i, xla::Slice(input, begin, limits, begin_expr, + limits_expr, strides)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 3c99ad63565266..59716591ad08bc 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -157,7 +157,8 @@ class StackPushOp : public XlaOpKernel { TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = xla::Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes(), + slice_shape.get_expressions()); // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index aa71c5c34d2e1a..2f19d1f2dd9a03 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -49,8 +49,9 @@ xla::BitGeneratorTy GetBitGeneratorForDevice( device_type_string == DEVICE_CPU_XLA_JIT) { return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { std::tie(state, key) = xla::ScramblePhiloxKey(key); - xla::XlaOp philox_state = - xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0); + xla::XlaOp philox_state = xla::ConcatInDim( + key.builder(), + {xla::Reshape(key, {1}, {xla::DynExpr::one}), state}, 0); xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, philox_state, shape); return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), @@ -421,19 +422,23 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); - auto bcasted_means = BroadcastTo(ctx->Input(2), shape.dim_sizes()); + auto bcasted_means = + BroadcastTo(ctx->Input(2), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_means.status()); auto means = bcasted_means.value(); - auto bcasted_stddevs = BroadcastTo(ctx->Input(3), shape.dim_sizes()); + auto bcasted_stddevs = + BroadcastTo(ctx->Input(3), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_stddevs.status()); auto stddevs = bcasted_stddevs.value(); - auto bcasted_minvals = BroadcastTo(ctx->Input(4), shape.dim_sizes()); + auto bcasted_minvals = + BroadcastTo(ctx->Input(4), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_minvals.status()); auto minvals = bcasted_minvals.value(); - auto bcasted_maxvals = BroadcastTo(ctx->Input(5), shape.dim_sizes()); + auto bcasted_maxvals = + BroadcastTo(ctx->Input(5), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_maxvals.status()); auto maxvals = bcasted_maxvals.value(); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index e15196bd756462..0555e3cf79cdd7 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -82,6 +82,9 @@ class StridedSliceOp : public XlaOpKernel { partial_final_shape.set_dim( i, input_shape.dim_size(shape_spec.output_to_processing_mapping[i])); + partial_final_shape.set_expression( + i, input_shape.get_expression( + shape_spec.output_to_processing_mapping[i])); } } @@ -97,6 +100,8 @@ class StridedSliceOp : public XlaOpKernel { // Use input shape to update unknown dimension of partial shape -- if a // dimension is unknown, we use input shape as bound. partial_processing_shape.set_dim(i, input_shape.dim_size(i)); + partial_processing_shape.set_expression(i, + input_shape.get_expression(i)); } } TensorShape processing_shape; @@ -157,7 +162,8 @@ class StridedSliceOp : public XlaOpKernel { auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); xla::XlaOp begin_index, end_index; int64_t sparse_index = shape_spec.processing_to_sparse_mapping[i]; - bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i); + bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i) || + input_xla_shape.expressions(i)->is_dynamic(); xla::XlaOp dim_size; if (xla_input_is_dynamic) { dim_size = xla::GetDimensionSize(ctx->Input(0), i); @@ -215,10 +221,12 @@ class StridedSliceOp : public XlaOpKernel { } slice = - xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes()); + xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes(), + processing_shape.get_expressions()); // new_axis_mask_, ellipsis_mask_ and shrink_axis_mask_ may add or remove // size 1 dims of a shape. - slice = xla::Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes(), + final_shape.get_expressions()); for (int64_t i = 0; i < final_shape.dims(); ++i) { int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i]; // If processing_shape_dim is -1, it means the output dimension was newly @@ -246,6 +254,8 @@ class StridedSliceOp : public XlaOpKernel { absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; @@ -268,14 +278,15 @@ class StridedSliceOp : public XlaOpKernel { PartialTensorShape partial_processing_shape, partial_final_shape; bool dummy = false; StridedSliceShapeSpec shape_spec; + OP_REQUIRES_OK( - ctx, - ValidateStridedSliceOp( - begin_is_constant ? &begin_tensor : nullptr, - end_is_constant ? &end_tensor : nullptr, strides_tensor, - input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &partial_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec)); + ctx, ValidateStridedSliceOp( + begin_is_constant ? &begin_tensor : nullptr, + end_is_constant ? &end_tensor : nullptr, strides_tensor, + input_shape, begin_mask_, end_mask_, ellipsis_mask_, + new_axis_mask_, shrink_axis_mask_, &partial_processing_shape, + &partial_final_shape, &dummy, &dummy, &dummy, &begin, &end, + &strides, &begin_expr, &end_expr, &shape_spec)); xla::XlaOp slice = ctx->Input(0); std::vector begins_are_dynamic; @@ -294,17 +305,28 @@ class StridedSliceOp : public XlaOpKernel { ", output shape must be a compile-time constant")); absl::InlinedVector dimensions_to_reverse; absl::InlinedVector slice_begin, slice_end, slice_strides; + absl::InlinedVector slice_begin_expr, slice_end_expr; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { slice_begin.push_back(begin[i]); + slice_begin_expr.push_back(begin_expr[i]); slice_end.push_back(std::max(end[i], begin[i])); + slice_end_expr.push_back((end[i] > begin[i]) ? end_expr[i] + : begin_expr[i]); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. + auto input_exprs = input_shape.get_expressions(); slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); + slice_begin_expr.push_back( + (*input_exprs[i] - *begin_expr[i] - 1)->s()); slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, input_shape.dim_size(i) - begin[i] - 1)); + slice_end_expr.push_back( + (end[i] < begin[i]) + ? (*input_exprs[i] - *end_expr[i] - 1)->s() + : (*input_exprs[i] - *begin_expr[i] - 1)->s()); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } @@ -312,7 +334,8 @@ class StridedSliceOp : public XlaOpKernel { if (!dimensions_to_reverse.empty()) { slice = xla::Rev(slice, dimensions_to_reverse); } - slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); + slice = xla::Slice(slice, slice_begin, slice_end, slice_begin_expr, + slice_end_expr, slice_strides); auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); OP_REQUIRES_OK(ctx, operand_shape_or.status()); xla::Shape xla_shape = operand_shape_or.value(); @@ -325,7 +348,8 @@ class StridedSliceOp : public XlaOpKernel { bool ends_are_static = absl::c_all_of( ends_are_dynamic, [](bool dynamic) { return !dynamic; }); // Static output shape, return a static slice. - slice = xla::Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes(), + final_shape.get_expressions()); if (xla_shape.is_static() && ends_are_static) { ctx->SetOutput(0, slice); return; @@ -436,6 +460,8 @@ class StridedSliceGradOp : public XlaOpKernel { PartialTensorShape processing_shape, final_shape; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; StridedSliceShapeSpec shape_spec; OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, @@ -445,7 +471,7 @@ class StridedSliceGradOp : public XlaOpKernel { nullptr, nullptr, strides_tensor, input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, &dummy, - &begin, &end, &strides, &shape_spec)); + &begin, &end, &strides, &begin_expr, &end_expr, &shape_spec)); for (int64_t i = 0; i < processing_shape.dims(); ++i) { OP_REQUIRES( ctx, strides[i] == 1, @@ -459,14 +485,17 @@ class StridedSliceGradOp : public XlaOpKernel { VLOG(1) << "xla final_shape" << final_shape; VLOG(1) << "input_shape" << input_shape.DebugString(); auto input_sizes = input_shape.dim_sizes(); + auto input_exprs = input_shape.get_expressions(); // For unknown output dim the bound of the output shape is input. Pad and // double the size of input shape to leave enough buffer to avoid OOB // dynamic update slice. auto input_sizes_padded = input_shape.dim_sizes(); + auto input_exprs_padded = input_shape.get_expressions(); bool need_padding = false; for (int64_t i = 0; i < processing_shape.dims(); ++i) { if (processing_shape.dim_size(i) == -1) { input_sizes_padded[i] *= 2; + input_exprs_padded[i] = (2 * *input_exprs_padded[i])->s(); need_padding = true; } } @@ -477,6 +506,7 @@ class StridedSliceGradOp : public XlaOpKernel { if (shape_spec.output_to_processing_mapping[i] != -1) { processing_shape.set_dim(shape_spec.output_to_processing_mapping[i], grad_shape.dimensions(i)); + // TODO Pass it back } } @@ -506,15 +536,19 @@ class StridedSliceGradOp : public XlaOpKernel { } auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); - zero = xla::Broadcast(zero, input_sizes_padded); - grad = xla::Reshape(grad, processing_shape.dim_sizes()); + zero = xla::Broadcast(zero, input_sizes_padded, input_exprs_padded); + grad = xla::Reshape(grad, processing_shape.dim_sizes(), + processing_shape.get_expressions()); grad = xla::DynamicUpdateSlice(zero, grad, begins); if (need_padding) { // We padded the input shape to avoid OOB when DUS. Now slice out the // padding in the final result. std::vector strides(input_shape.dims(), 1); std::vector start_indices(input_shape.dims(), 0); - grad = xla::Slice(grad, start_indices, input_sizes, strides); + std::vector start_exprs(input_shape.dims(), + xla::DynExpr::zero); + grad = xla::Slice(grad, start_indices, input_sizes, start_exprs, + input_exprs, strides); } ctx->SetOutput(0, grad); } @@ -522,6 +556,8 @@ class StridedSliceGradOp : public XlaOpKernel { TensorShape processing_shape, final_shape; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; TensorShape input_shape; @@ -547,11 +583,12 @@ class StridedSliceGradOp : public XlaOpKernel { bool dummy = false; OP_REQUIRES_OK( - ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, input_shape, - begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &processing_shape, &final_shape, &dummy, - &dummy, &dummy, &begin, &end, &strides)); + ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, input_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, + &dummy, &begin, &end, &strides, &begin_expr, &end_expr)); // Check to make sure dy is consistent with the original slice const TensorShape dy_shape = ctx->InputShape(4); @@ -570,7 +607,8 @@ class StridedSliceGradOp : public XlaOpKernel { xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. - grad = xla::Reshape(grad, processing_shape.dim_sizes()); + grad = xla::Reshape(grad, processing_shape.dim_sizes(), + processing_shape.get_expressions()); // Pad the input gradients. absl::InlinedVector dimensions_to_reverse; @@ -662,6 +700,8 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape final_shape; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; @@ -690,12 +730,13 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape dummy_processing_shape; bool dummy = false; - OP_REQUIRES_OK(ctx, - ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, lhs_shape, - begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &dummy_processing_shape, &final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK( + ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, lhs_shape, begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &dummy_processing_shape, &final_shape, &dummy, &dummy, &dummy, + &begin, &end, &strides, &begin_expr, &end_expr)); if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { // DynamicUpdateSlice does not allow 0-element updates. We should probably @@ -717,6 +758,7 @@ class StridedSliceAssignOp : public XlaOpKernel { absl::InlinedVector dimensions_to_reverse; absl::InlinedVector slice_begin; absl::InlinedVector slice_dims; + absl::InlinedVector slice_exprs; for (int i = 0; i < begin.size(); ++i) { // TODO(b/121179231): implement strides != 1 OP_REQUIRES( @@ -726,12 +768,14 @@ class StridedSliceAssignOp : public XlaOpKernel { slice_begin.push_back( xla::ConstantR0(ctx->builder(), begin[i])); slice_dims.push_back(end[i] - begin[i]); + slice_exprs.push_back(xla::DynExpr::_(end[i] - begin[i])); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. slice_begin.push_back( xla::ConstantR0(ctx->builder(), end[i] + 1)); slice_dims.push_back(begin[i] - end[i]); + slice_exprs.push_back(xla::DynExpr::_(begin[i] - end[i])); dimensions_to_reverse.push_back(i); } } @@ -739,7 +783,7 @@ class StridedSliceAssignOp : public XlaOpKernel { if (!dimensions_to_reverse.empty()) { rhs = xla::Rev(rhs, dimensions_to_reverse); } - rhs = xla::Reshape(rhs, slice_dims); + rhs = xla::Reshape(rhs, slice_dims, slice_exprs); lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 888908e30b2331..2fb7f134283f2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -167,7 +167,8 @@ class TensorArrayOp : public XlaOpKernel { ta_shape.AddDim(size); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); - value = xla::Broadcast(zero, ta_shape.dim_sizes()); + value = xla::Broadcast(zero, ta_shape.dim_sizes(), + ta_shape.get_expressions()); } XlaResource* var = @@ -223,7 +224,8 @@ class TensorArrayWriteOp : public XlaOpKernel { TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = xla::Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes(), + slice_shape.get_expressions()); xla::XlaOp written; if (resource->tensor_array_multiple_writes_aggregate()) { @@ -274,9 +276,12 @@ class TensorArrayReadOp : public XlaOpKernel { start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); + auto slice_exprs = ta_shape.get_expressions(); slice_shape[0] = 1LL; + slice_exprs[0] = xla::DynExpr::_(1LL); - xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = + xla::DynamicSlice(ta, start_indices, slice_shape, slice_exprs); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, @@ -469,9 +474,12 @@ class TensorArrayConcatOp : public XlaOpKernel { xla::XlaOp ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); + auto ta_exprs = ta_shape.get_expressions(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); + std::vector exprs(ta_exprs.begin() + 1, ta_exprs.end()); shape[0] *= ta_shape.dim_size(0); - ctx->SetOutput(0, xla::Reshape(ta, shape)); + exprs[0] = *ta_exprs[0] * *ta_shape.get_expression(0); + ctx->SetOutput(0, xla::Reshape(ta, shape, exprs)); Tensor lengths(DT_INT64, {ta_dims[0]}); auto lengths_vec = lengths.vec(); @@ -526,6 +534,7 @@ class TensorArraySplitOp : public XlaOpKernel { TensorShape ta_shape; ta_shape.AddDim(resource->max_array_size()); + ta_shape.AddExpression(xla::DynExpr::_(resource->max_array_size())); ta_shape.AppendShape(elem_shape); OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(), @@ -541,7 +550,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + const xla::XlaOp reshape = + xla::Reshape(value, ta_shape.dim_sizes(), ta_shape.get_expressions()); if (dtype_ == DT_BOOL) { ta = xla::Or(ta, reshape); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index a1f58d5ae9b40e..ae0f27a4fad59f 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -131,7 +131,8 @@ absl::Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, return absl::OkStatus(); } - *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes()); + *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes(), + partial_shape.get_expressions()); *got_shape = true; return absl::OkStatus(); } @@ -503,6 +504,8 @@ class TensorListConcatOp : public XlaOpKernel { xla::Shape element_shape = std::move(shape_or).value(); std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); OP_REQUIRES( ctx, element_dims.size() > 1, errors::Unimplemented("TensorList of scalars is not supported")); @@ -510,12 +513,15 @@ class TensorListConcatOp : public XlaOpKernel { int64_t tensor_lengths = element_dims[1]; std::vector new_dims = {num_elements * tensor_lengths}; + std::vector new_exprs = { + xla::DynExpr::_(num_elements * tensor_lengths)}; for (int i = 2; i < element_dims.size(); i++) { new_dims.push_back(element_dims[i]); + new_exprs.push_back(element_exprs[i]); } - xla::XlaOp out = xla::Reshape(buffer, new_dims); + xla::XlaOp out = xla::Reshape(buffer, new_dims, new_exprs); ctx->SetOutput(0, out); // Second output is a tensor of lengths of returned tensors. @@ -550,6 +556,8 @@ class TensorListSplitOp : public XlaOpKernel { xla::Shape element_shape = std::move(shape_or).value(); std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); OP_REQUIRES( ctx, !element_dims.empty(), errors::Unimplemented("Element dimensions have to be non-empty")); @@ -569,11 +577,13 @@ class TensorListSplitOp : public XlaOpKernel { ctx, element_dims[0] % length == 0, errors::Unimplemented("Buffer size has to be a multiple of length")); std::vector new_dims = {element_dims[0] / length, length}; + std::vector new_exprs = {*element_exprs[0] / length, + xla::DynExpr::_(length)}; for (int i = 1; i < element_dims.size(); i++) { new_dims.push_back(element_dims[i]); } - xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims); + xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims, new_exprs); xla::XlaOp result; OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result)); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 683dc4737e6dab..1771181e440f31 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -389,7 +389,12 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector element_part_dims = xla::SpanToVector(element_part_shape.dimensions()); element_part_dims.insert(element_part_dims.begin(), 1); - element_part = xla::Reshape(element_part, element_part_dims); + std::vector element_part_exprs = + xla::SpanToVector(element_part_shape.expressions()); + element_part_exprs.insert(element_part_exprs.begin(), + xla::DynExpr::one); + element_part = + xla::Reshape(element_part, element_part_dims, element_part_exprs); std::vector start_indices( element_part_shape.dimensions().size() + 1, @@ -406,7 +411,10 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); element_dims.insert(element_dims.begin(), 1); - xla::XlaOp update = xla::Reshape(element, element_dims); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); + element_exprs.insert(element_exprs.begin(), xla::DynExpr::one); + xla::XlaOp update = xla::Reshape(element, element_dims, element_exprs); std::vector start_indices(element_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); @@ -455,11 +463,16 @@ absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, xla::SpanToVector(list_part_shape.dimensions()); slice_shape[0] = 1LL; + std::vector slice_exprs = + xla::SpanToVector(list_part_shape.expressions()); + slice_exprs[0] = xla::DynExpr::_(1LL); + xla::XlaOp list_part = xla::GetTupleElement(list, i); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); slice_shape.erase(slice_shape.begin()); - element_result_parts.push_back(xla::Reshape(read, slice_shape)); + element_result_parts.push_back( + xla::Reshape(read, slice_shape, slice_exprs)); list_result_parts.push_back(list_part); } list_result_parts.push_back(push_index); @@ -493,7 +506,11 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); element_dims.insert(element_dims.begin(), 1); - xla::XlaOp update = xla::Reshape(element, element_dims); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); + element_exprs.insert(element_exprs.begin(), xla::DynExpr::one); + + xla::XlaOp update = xla::Reshape(element, element_dims, element_exprs); std::vector start_indices(element_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); @@ -557,6 +574,10 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::SpanToVector(buffer_shape.dimensions()); slice_shape[0] = 1LL; + std::vector slice_exprs = + xla::SpanToVector(buffer_shape.expressions()); + slice_exprs[0] = xla::DynExpr::_(1LL); + xla::XlaOp list_part = xla::GetTupleElement(list, 0); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); // Propagate dynamic dimensions from buffer to the sliced buffer, except for @@ -569,7 +590,8 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, } } slice_shape.erase(slice_shape.begin()); - *result = xla::Reshape(read, slice_shape); + slice_exprs.erase(slice_exprs.begin()); + *result = xla::Reshape(read, slice_shape, slice_exprs); return absl::OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 6c39981ba5b937..3eab14bf78e968 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -67,11 +67,17 @@ class TileOp : public XlaOpKernel { xla::ValueInferenceMode::kUpperBound)); std::vector output_dims(input_shape.dims()); + std::vector output_exprs(input_shape.dims()); + + auto expr_sizes = input_shape.get_expressions(); + for (int64_t i = 0; i < input_shape.dims(); ++i) { OP_REQUIRES(ctx, multiples_bounds[i] >= 0, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", output_dims[i])); output_dims[i] = input_shape.dim_size(i) * multiples_bounds[i]; + output_exprs[i] = + (*expr_sizes[i] * *xla::DynExpr::_(multiples_bounds[i]))->s(); } std::vector multiples_are_dynamic; @@ -91,8 +97,8 @@ class TileOp : public XlaOpKernel { return; } } - - auto result_or = BroadcastTo(ctx->Input("input"), output_dims); + auto result_or = + BroadcastTo(ctx->Input("input"), output_dims, output_exprs); OP_REQUIRES_OK(ctx, result_or.status()); auto result = result_or.value(); diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index 46de3dd89b6115..fbe181d13ef547 100644 --- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -83,9 +83,12 @@ class UniqueOpBase : public XlaOpKernel { // // This is implemented as an hlo while loop. xla::XlaOp RollingSelectR1(XlaOpKernelContext* ctx, xla::XlaOp data, - xla::XlaOp mask, int64_t size) { + xla::XlaOp mask, int64_t size, + xla::DynExpr* expr) { xla::XlaComputation cond, body; - const xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size}); + xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size}); + r1_shape.set_expression(0, expr); + const xla::Shape counter_shape = xla::ShapeUtil::MakeScalarShape(xla::S32); const xla::Shape& single_element_shape = counter_shape; @@ -136,7 +139,7 @@ class UniqueOpBase : public XlaOpKernel { } auto zero = xla::Zero(ctx->builder(), xla::S32); - auto zero_broadcast = xla::Broadcast(zero, {size}); + auto zero_broadcast = xla::Broadcast(zero, {size}, {expr}); auto init = xla::Tuple(ctx->builder(), {zero, data, mask, zero_broadcast}); return xla::GetTupleElement(xla::While(cond, body, init), 3); } @@ -153,13 +156,19 @@ class UniqueOpBase : public XlaOpKernel { auto aux = MoveAxis(input, axis, 0, input_shape); auto aux_shape = ctx->builder()->GetShape(aux).value(); int64_t leading_size = aux_shape.dimensions(0); + auto leading_expr = aux_shape.expressions(0); int64_t product = 1; + auto product_expr = xla::DynExpr::one; for (int64_t i = 1; i < aux_shape.dimensions().size(); ++i) { product *= aux_shape.dimensions(i); + product_expr = *(product_expr) * *(aux_shape.expressions(i)); } - aux = xla::Reshape(aux, {leading_size, product}); + product_expr = product_expr->s(); + aux = xla::Reshape(aux, {leading_size, product}, + {leading_expr, product_expr}); if (leading_size == 0) { - auto result_data = xla::Reshape(aux, aux_shape.dimensions()); + auto result_data = + xla::Reshape(aux, aux_shape.dimensions(), aux_shape.expressions()); result_data = MoveAxis(result_data, 0, axis, aux_shape); ctx->SetOutput(0, result_data); ctx->SetOutput(1, xla::Iota(ctx->builder(), xla::S32, leading_size)); @@ -171,10 +180,13 @@ class UniqueOpBase : public XlaOpKernel { sort_types.reserve(product + 1); for (int64_t i = 0; i < product; ++i) { xla::XlaOp slice = xla::SliceInDim(aux, i, i + 1, 1, 1); - sort_keys.push_back(xla::Reshape(slice, {leading_size})); + sort_keys.push_back(xla::Reshape(slice, {leading_size}, {leading_expr})); sort_types.push_back(input_shape.element_type()); } - auto iota = xla::Iota(ctx->builder(), xla::S32, leading_size); + xla::Shape iota_shape = + xla::ShapeUtil::MakeShape(xla::S32, {leading_size}, {leading_expr}); + iota_shape.set_expression(0, leading_expr); + auto iota = xla::Iota(ctx->builder(), iota_shape, 0); sort_keys.push_back(iota); sort_types.push_back(xla::S32); @@ -202,16 +214,18 @@ class UniqueOpBase : public XlaOpKernel { gather_dim_numbers.add_collapsed_slice_dims(0); auto permuted = xla::Gather(aux, perm, gather_dim_numbers, {1, product}); // Tail is everything except for first element. - auto tail = xla::SliceInDim(permuted, 1, leading_size, 1, 0); + auto tail = xla::SliceInDim(permuted, 1, leading_size, + xla::DynExpr::one, leading_expr, 1, 0); // Init is everything except for last element. - auto init = xla::SliceInDim(permuted, 0, leading_size - 1, 1, 0); + auto init = xla::SliceInDim(permuted, 0, leading_size - 1, + xla::DynExpr::zero, *leading_expr - 1, 1, 0); auto ne = xla::Compare(tail, init, xla::ComparisonDirection::kNe); auto reduce = xla::Reduce(ne, xla::ConstantR0(ctx->builder(), false), CreateScalarOrComputation(xla::PRED, ctx->builder()), {1}); auto mask = xla::ConvertElementType(reduce, xla::S32); mask = xla::PadInDim(mask, xla::One(ctx->builder(), xla::S32), 0, 1, 0); - auto iperm = RollingSelectR1(ctx, perm, mask, leading_size); + auto iperm = RollingSelectR1(ctx, perm, mask, leading_size, leading_expr); auto sort_by_iperm = xla::Sort({iperm, mask, perm}, @@ -232,12 +246,15 @@ class UniqueOpBase : public XlaOpKernel { /*is_stable=*/true); auto mask_permute = xla::GetTupleElement(mask_sort, 1); permuted = xla::Gather(aux, mask_permute, gather_dim_numbers, {1, product}); - auto result_data = xla::Reshape(permuted, aux_shape.dimensions()); + auto result_data = + xla::Reshape(permuted, aux_shape.dimensions(), aux_shape.expressions()); result_data = MoveAxis(result_data, 0, axis, aux_shape); result_data = xla::SetDimensionSize(result_data, dynamic_size, axis); ctx->SetOutput(0, result_data); auto imask = CumSumR1(ctx, mask, leading_size); - imask = xla::Sub(imask, xla::One(ctx->builder(), xla::S32), {}); + auto one = xla::One(ctx->builder(), xla::S32); + auto one_broadcast = xla::Broadcast(one, {leading_size}, {leading_expr}); + imask = xla::Sub(imask, one_broadcast, {}); auto idx = xla::GetTupleElement( xla::Sort({perm_sort, imask}, xla::CreateScalarLtComputation({xla::S32, xla::S32}, diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index cca29f7f585907..d83c77e1b5a68f 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -69,7 +69,8 @@ class UnpackOp : public XlaOpKernel { limit_indices[axis] = i + 1; auto slice = xla::Slice(input, start_indices, limit_indices, strides); // Reshape to drop the 'axis' dimension. - auto result = xla::Reshape(slice, output_shape.dim_sizes()); + auto result = xla::Reshape(slice, output_shape.dim_sizes(), + output_shape.get_expressions()); ctx->SetOutput(i, result); } } diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index f97e6d5077efa7..29bcf4fa0769ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -165,7 +165,13 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions()); int64_t flattened_size = xla::Product(iota_shape.dimensions()); - XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size}); + xla::DynExpr* flattened_expr = xla::DynExpr::one; + for (auto e : iota_shape.expressions()){ + flattened_expr = *flattened_expr * *e; + } + flattened_expr = flattened_expr->s(); + XlaOp reshaped_condition = + xla::Reshape(condition, {flattened_size}, {flattened_expr}); XlaOp zeros = xla::ZerosLike(reshaped_condition); XlaOp compared = xla::Ne(reshaped_condition, zeros); @@ -175,7 +181,7 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { // indices of each element. for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis); - XlaOp reshaped = xla::Reshape(iota, {flattened_size}); + XlaOp reshaped = xla::Reshape(iota, {flattened_size}, {flattened_expr}); to_sort.push_back(reshaped); types_to_sort.push_back(xla::S32); } @@ -186,7 +192,8 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { std::vector to_concat; for (int64_t i = 0; i < iota_shape.dimensions_size(); ++i) { XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1); - to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1})); + to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}, + {flattened_expr, xla::DynExpr::one})); } XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1); @@ -214,7 +221,13 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { TF_ASSIGN_OR_RETURN(xla::Shape input_shape, b->GetShape(condition)); int64_t flattened_size = xla::Product(input_shape.dimensions()); - XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size}); + xla::DynExpr* flattened_expr = xla::DynExpr::one; + for (auto e : input_shape.expressions()) { + flattened_expr = *flattened_expr * *e; + } + flattened_expr = flattened_expr->s(); + XlaOp reshaped_condition = + xla::Reshape(condition, {flattened_size}, {flattened_expr}); XlaOp zeros = xla::ZerosLike(reshaped_condition); XlaOp preds = xla::ConvertElementType(xla::Ne(reshaped_condition, zeros), S32); @@ -253,7 +266,8 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { XlaOp out_idxs = xla::Select(xla::Ne(prefix_sum, prefix_sum_shifted), /*on_true=*/prefix_sum - xla::One(b, S32), /*on_false=*/oob_idx); - out_idxs = xla::Reshape(out_idxs, {flattened_size, 1}); + out_idxs = xla::Reshape(out_idxs, {flattened_size, 1}, + {flattened_expr, xla::DynExpr::one}); // tf.where returns an array of multidimensional indices where the condition // is true. For example: @@ -280,7 +294,8 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { iotas_to_concat.reserve(iota_shape.dimensions_size()); for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { iotas_to_concat.push_back( - xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1})); + xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}, + {flattened_expr, xla::DynExpr::one})); } XlaOp iotas = xla::ConcatInDim(b, iotas_to_concat, /*dimension=*/1); diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index b000c49f1f962e..fec1ce280cf922 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -133,6 +133,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { for (int64_t i = 0; i < original_shape.dimensions().size(); ++i) { to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + to_shape.set_expression(i, original_shape.expressions(i)); } } return xla::Reshape(to_shape, original); diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index f815b91d04be33..c866c4429d4818 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -32,9 +32,10 @@ limitations under the License. namespace tensorflow { -absl::StatusOr BroadcastTo(xla::XlaOp input, - absl::Span output_dims) { - return xla::BroadcastTo(input, output_dims); +absl::StatusOr BroadcastTo( + xla::XlaOp input, absl::Span output_dims, + absl::Span output_exprs) { + return xla::BroadcastTo(input, output_dims, output_exprs); } absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h index 60630971e27466..ee56975b664f7b 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.h +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -29,8 +29,9 @@ namespace tensorflow { // Forwards to xla::BroadcastTo. // TODO(cheshire): Call the underlying function directly. -absl::StatusOr BroadcastTo(xla::XlaOp input, - absl::Span output_dims); +absl::StatusOr BroadcastTo( + xla::XlaOp input, absl::Span output_dims, + absl::Span output_exprs = {}); // Forwards to xla::BroadcastOpsToSame. absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 2473b97af4c2bd..38116fddb200af 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -53,7 +53,12 @@ absl::StatusOr Contract(xla::XlaOp input, int64_t dim) { input_shape.dimensions().end() - 1); contracted_shape[dim] *= 4; - return xla::Reshape(xla::Transpose(input, permutation), contracted_shape); + std::vector contracted_exprs( + input_shape.expressions().begin(), input_shape.expressions().end() - 1); + contracted_exprs[dim] = (*(contracted_exprs[dim]) * *xla::DynExpr::_(4))->s(); + + return xla::Reshape(xla::Transpose(input, permutation), contracted_shape, + contracted_exprs); } absl::StatusOr Expand(xla::XlaOp input, int64_t dim) { @@ -85,7 +90,7 @@ absl::StatusOr Expand(xla::XlaOp input, int64_t dim) { } permutation.push_back(dim + 1); - return xla::Transpose(xla::Reshape(input, expanded_shape), permutation); + return xla::Transpose(xla::Reshape(input, expanded_shape, {}), permutation); } } // namespace diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 6a67cfa237af70..fa417215032662 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -1143,16 +1143,22 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, } std::vector dims; std::vector dynamic_dims; + std::vector expressions; for (int i = 0, rank = c->Rank(shape_handle); i < rank; ++i) { bool is_dynamic = !c->ValueKnown(c->Dim(shape_handle, i)); + int dynamic_multiplier = c->DynamicRatio(c->Dim(shape_handle, i)); dynamic_dims.push_back(is_dynamic); + expressions.push_back(dynamic_multiplier * *xla::DynExpr::V(1)); dims.push_back(is_dynamic ? xla::Shape::kUnboundedSize : c->Value(c->Dim(shape_handle, i))); } - return xla::Shape( + xla::Shape sh( // Type matters only for indices. S64 is the widest possible type. xla::PrimitiveType::S64, dims, absl::InlinedVector(dynamic_dims.begin(), dynamic_dims.end())); + + sh.set_expressions(expressions); + return sh; } REGISTER_OP("XlaGather") diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 0d7549d81c20f6..f2dccaea7b1cac 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -100,6 +100,9 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } + std::vector bexprs(shape.expressions().begin(), + shape.expressions().end()); + tensor_shape->set_expressions(bexprs); return absl::OkStatus(); } @@ -167,6 +170,7 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, } int rank = tensor_shape.dims(); std::vector dimensions(rank); + std::vector expressions(rank); std::vector layout(rank); for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); @@ -175,11 +179,13 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, "shape; returning unknown sentinel value"; return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); } + expressions[d] = tensor_shape.get_expression(d); } // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); xla::Shape result = xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + result.set_expressions(expressions); return result; } @@ -200,18 +206,29 @@ absl::StatusOr TensorShapeToXLAShape( return out; } +inline static int var_id = 1; + xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector dimensions(rank); std::vector layout(rank); + std::vector expressions(rank); + for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); + expressions[d] = (d < tensor_shape.get_expressions().size()) + ? tensor_shape.get_expression(d) + : xla::DynExpr::_(dimensions[d]); } + // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + auto shape = + xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + shape.set_expressions(expressions); + return shape; } absl::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index 86811d3416eb46..9873082c3f4d6e 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -53,9 +53,6 @@ struct XlaArgument { kTensorList, }; - //To keep dynamic dim as an attribute of the argument. - int64_t dynamic_dim = -1; - Kind kind = kInvalid; // The type of the argument. If the argument is a resource, this @@ -119,6 +116,7 @@ struct XlaArgument { // Returns the dimension sizes for either TensorShape or xla::Shape. std::vector DimensionSizes() const; + std::vector DimensionExpressions() const; absl::InlinedVector DimensionSizesAsInlinedVector() const; // Returns the human-readable string for either TensorShape or xla::Shape. diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 65df8e1fc15c73..84a5b1e5c05eca 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -511,6 +511,14 @@ std::vector XlaCompiler::Argument::DimensionSizes() const { } } +std::vector XlaCompiler::Argument::DimensionExpressions() const { + if (absl::holds_alternative(shape)) { + return std::get(shape).get_expressions(); + } else { + return xla::SpanToVector(std::get(shape).expressions()); + } +} + absl::InlinedVector XlaCompiler::Argument::DimensionSizesAsInlinedVector() const { if (absl::holds_alternative(shape)) { @@ -839,9 +847,9 @@ absl::Status XlaCompiler::CompileFunction( std::vector{tensor_shape}); } } else { + auto* val_ptr = std::get_if(&args[i].shape); TensorShape tensor_shape = std::get(args[i].shape); AttrSlice n_attrs = fbody->arg_nodes[i]->attrs(); - std::vector output_shapes; fbody->arg_nodes[i]->ClearAttr("_output_shapes"); fbody->arg_nodes[i]->AddAttr("_output_shapes", std::vector{tensor_shape}); @@ -1085,6 +1093,7 @@ absl::Status XlaCompiler::BuildArguments( TF_RET_CHECK(absl::holds_alternative(arg.shape)); // TODO(phawkins): this code assumes that resource arguments do not // alias. + auto* val_ptr = std::get_if(&arg.shape); XlaResource* resource = context->AddResource(std::make_unique( arg.resource_kind, i, arg.name, arg.type, @@ -1196,13 +1205,6 @@ absl::Status XlaCompiler::BuildArguments( xla::OpMetadata arg_metadata; arg_metadata.set_op_name(arg.node_name); - - if (arg.dynamic_dim==0) { - // Encode dynamic dims as a string in op_type, so it appears in HLO metadata. - arg_metadata.set_op_type( - absl::StrCat("XLA_Arg_dyn[",arg.dynamic_dim, "]")); - } - builder->SetOneShotOpMetadata(arg_metadata); arg_handles[i] = xla::GetTupleElement(tuple, i); } @@ -1215,12 +1217,7 @@ absl::Status XlaCompiler::BuildArguments( auto& arg = args[input_to_args->at(i)]; xla::OpMetadata arg_metadata; arg_metadata.set_op_name(arg.node_name); - if (arg.dynamic_dim==0) { - arg_metadata.set_op_type( - absl::StrCat("XLA_Arg_dyn[",arg.dynamic_dim, "]")); - } builder->SetOneShotOpMetadata(arg_metadata); - if (is_entry_computation) { // Add an entry to is_same_across_replicas for every leaf buffer. std::vector is_same_across_replicas( @@ -1240,10 +1237,9 @@ absl::Status XlaCompiler::BuildArguments( // Fill in the handles in non-constant arguments, and reshape parameters // back to their correct shapes. - VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; - VLOG(2) << " XLA arg " << i + VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) << " name: " << arg.name << " TF arg " << input_to_args->at(i) << " node name: " << arg.node_name @@ -1269,7 +1265,9 @@ absl::Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression = XlaExpression::XlaOp( - xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type); + xla::Reshape(arg_handles[i], arg.DimensionSizes(), + arg.DimensionExpressions()), + arg.type); } else { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); if (arg.value_bound) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 4a570827029330..e022b2fbec258b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -736,7 +736,8 @@ absl::Status AssignVariableTensor(const Tensor& tensor, DataType type, xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { - handle = xla::Reshape(handle, representation_shape.dimensions()); + handle = xla::Reshape(handle, representation_shape.dimensions(), + representation_shape.expressions()); } variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 2f0ff5e91867f1..72bd6f712672e3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1740,6 +1740,7 @@ tf_cuda_library( "@local_xla//xla/tsl/framework:cancellation", "@local_xla//xla/tsl/util:command_line_flags", "@local_xla//xla/tsl/util:device_name_utils", + "@local_xla//xla:shape_util", ] + if_cuda([ "@local_config_cuda//cuda:cudnn_header", ]) + if_static( diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 6820a5ddd696d3..0a80e85ba83955 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -327,7 +327,7 @@ void ConsiderConstantFoldableNode( std::unordered_map>* constant_control_deps, std::unordered_map>* shape_replacement_map, bool* internal_node_inserted) { - if (!IsConstantFoldable(n, opts.shape_map, opts.consider, + if (!IsConstantFoldable(n, opts.shape_map, opts.consider, opts.max_constant_size_in_bytes, shape_replacement_map)) { return; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 711c148aec958e..2c0d2ba5ee7ee9 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -342,6 +342,8 @@ filegroup( "tensor_key.h", "tensor_shape.cc", "tensor_shape.h", + "tensor_shape_expr.cc", + "tensor_shape_expr.h", "tensor_types.h", "tensor_util.h", "tracking_allocator.h", @@ -720,6 +722,23 @@ cc_library( "//tensorflow/core/platform:statusor", "//tensorflow/core/util:overflow", "@eigen_archive//:eigen3", + "@local_xla//xla:shape_util", + ], + alwayslink = 1, +) + +cc_library( + name = "tensor_shape_expr", + srcs = ["tensor_shape_expr.cc"], + hdrs = ["tensor_shape_expr.h"], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/core/grappler:__subpackages__", + "//tensorflow/core/runtime_fallback:__subpackages__", + "//tensorflow/core/tfrt/utils:__subpackages__", + ], + deps = [ + ":tensor_shape_proto_cc", ], alwayslink = 1, ) @@ -923,6 +942,7 @@ cc_library( ":node_def_util", ":op_def_proto_cc", ":tensor_shape", + ":tensor_shape_expr", ":tensor_shape_proto_cc", "//tensorflow/core/lib/core:errors", "//tensorflow/core/lib/core:status", diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index aefa2a416310e2..42cf77245b8e1b 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -2369,7 +2369,13 @@ absl::Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } else if (dim_y.SameHandle(dim_x)) { dims.push_back(dim_x); } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) { - dims.push_back(c->UnknownDim()); + DimensionHandle merged; + absl::Status s = c->Merge(dim_x, dim_y, &merged); + if (s.ok()) { + dims.push_back(merged); + } else { + dims.push_back(c->UnknownDim()); + } } else { if (!incompatible_shape_error) { *out = c->UnknownShape(); diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index b63269f68c3368..5a4fe02f147f4c 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -248,6 +248,10 @@ void InferenceContext::ShapeHandleToProto(ShapeHandle handle, dim_shape->set_size(Value(dim)); } else { dim_shape->set_size(-1); + // Serialize expression if available. + if (DynExpr* expr = DimExpr(dim)) { + expr->ToProto(dim_shape->mutable_expr()); + } } } } @@ -282,6 +286,36 @@ DimensionHandle InferenceContext::NumElements(ShapeHandle s) { } } +DimensionHandle InferenceContext::UnknownDimWithExpr( + std::unique_ptr expr) { + DynExpr* owned = shape_manager_.OwnExpr(std::move(expr)); + return shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/0, owned); +} + +DynExpr* InferenceContext::DimExpr(DimensionHandle d) const { + if (!d.IsSet()) return nullptr; + return d->expr_; +} + +DynExpr* InferenceContext::MakeConstExpr(int64_t v) { + return shape_manager_.OwnExpr(std::make_unique(v)); +} + +DynExpr* InferenceContext::ExprForDim(DimensionHandle d) { + if (!d.IsSet()) return nullptr; + + // If already tagged with expr, use it. + if (DynExpr* e = DimExpr(d)) return e; + + // Known dim -> const expr. + if (ValueKnown(d)) { + return MakeConstExpr(Value(d)); + } + + // Unknown dim with no expr -> cannot form expression. + return nullptr; +} + string InferenceContext::DebugString(ShapeHandle s) { if (RankKnown(s)) { std::vector vals; @@ -293,7 +327,7 @@ string InferenceContext::DebugString(ShapeHandle s) { } string InferenceContext::DebugString(DimensionHandle d) { - return ValueKnown(d) ? strings::StrCat(Value(d)) : "?"; + return ValueKnown(d) ? strings::StrCat(Value(d), strings::StrCat("~",DynamicRatio(d))) : "?"; } string InferenceContext::DebugString() const { @@ -895,7 +929,12 @@ absl::Status InferenceContext::MakeShapeFromPartialTensorShape( for (int i = 0; i < num_dims; ++i) { // -1 is unknown in PartialTensorShape and in InferenceContext, so this size // can be passed directly to MakeDim. - dims[i] = MakeDim(partial_shape.dim_size(i)); + if(i == 0){ + dims[i] = MakeDim(partial_shape.dim_size(i), 1); + } + else { + dims[i] = MakeDim(partial_shape.dim_size(i)); + } } return ReturnCreatedShape(dims, out); } @@ -923,8 +962,38 @@ absl::Status InferenceContext::MakeShapeFromShapeProto( const TensorShapeProto& proto, ShapeHandle* out) { *out = nullptr; TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); - PartialTensorShape partial_shape(proto); - return MakeShapeFromPartialTensorShape(partial_shape, out); + + if (proto.unknown_rank()) { + *out = UnknownShape(); + return absl::OkStatus(); + } + + std::vector dims; + dims.reserve(proto.dim_size()); + for (int i = 0; i < proto.dim_size(); ++i) { + const auto& dim_proto = proto.dim(i); + if (dim_proto.size() >= 0) { + // Known dimension + dims.push_back(MakeDim(dim_proto.size())); + } else { + // Unknown dimension - check for expression + if (dim_proto.has_expr() && dim_proto.expr().node_type_case() != + ExpressionProto::NODE_TYPE_NOT_SET) { + // Deserialize expression + std::unique_ptr expr = DynExpr::FromProto(dim_proto.expr()); + if (expr) { + DynExpr* owned = shape_manager_.OwnExpr(std::move(expr)); + dims.push_back(shape_manager_.MakeDim(kUnknownDim,/*dynamic_ratio */ 0, owned)); + } else { + dims.push_back(UnknownDim()); + } + } else { + dims.push_back(UnknownDim()); + } + } + } + *out = MakeShape(dims); + return absl::OkStatus(); } absl::Status InferenceContext::GetScalarFromTensor(const Tensor* t, @@ -1030,24 +1099,40 @@ absl::Status InferenceContext::Divide(DimensionHandle dividend, DimensionOrConstant divisor, bool evenly_divisible, DimensionHandle* out) { - const int64_t divisor_value = Value(divisor); - if (divisor_value == 1) { + const bool dividend_known = ValueKnown(dividend); + const bool divisor_known = ValueKnown(divisor); + + // Validate divisor if known. + if (divisor_known && Value(divisor) <= 0) { + return errors::InvalidArgument("Divisor must be positive but is ", + Value(divisor)); + } + // Fast-path: x / 1 = x + if (divisor_known && Value(divisor) == 1) { *out = dividend; - } else if (!ValueKnown(dividend) || - (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) { - *out = UnknownDim(); - } else { + return absl::OkStatus(); + } + // If both known, do numeric divide. + if (dividend_known && divisor_known) { const int64_t v = Value(dividend); - if (divisor_value <= 0) { - return errors::InvalidArgument("Divisor must be positive but is ", - divisor_value); - } - if (evenly_divisible && (v % divisor_value) != 0) { + const int64_t d = Value(divisor); + if (evenly_divisible && (v % d) != 0) { return errors::InvalidArgument( - "Dimension size must be evenly divisible by ", divisor_value, - " but is ", v); + "Dimension size must be evenly divisible by ", d, " but is ", v); } - *out = MakeDim(v / divisor_value); + *out = MakeDim(v / d); + return absl::OkStatus(); + } + // At least one operand unknown: try to build expression. + DynExpr* lhs = ExprForDim(dividend); + DynExpr* rhs = divisor.dim.IsSet() ? ExprForDim(divisor.dim) + : MakeConstExpr(divisor.val); + if (lhs && rhs) { + DynExpr* node = shape_manager_.OwnExpr( + std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/0, node); + } else { + *out = UnknownDim(); // Can't form expr. } return absl::OkStatus(); } @@ -1055,26 +1140,41 @@ absl::Status InferenceContext::Divide(DimensionHandle dividend, absl::Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second, DimensionHandle* out) { - const int64_t first_value = Value(first); - const int64_t second_value = Value(second); - // Special cases. - if (first_value == 0) { + const bool first_known = ValueKnown(first); + const bool second_known = ValueKnown(second); + + // Fast-path: x + 0 = x + if (first_known && Value(first) == 0) { *out = MakeDim(second); - } else if (second_value == 0) { + return absl::OkStatus(); + } + if (second_known && Value(second) == 0) { *out = first; - } else if (first_value == kUnknownDim || second_value == kUnknownDim) { - *out = UnknownDim(); - } else { - // Invariant: Both values are known and positive. Still in run-time we can - // get pair of values which cannot be store in output. Check below will - // report error. We still need to avoid undefined behavior of signed - // overflow and use unsigned addition. - const int64_t sum = static_cast(first_value) + second_value; + return absl::OkStatus(); + } + + // If both known, do numeric add. + if (first_known && second_known) { + const int64_t sum = static_cast(Value(first)) + + static_cast(Value(second)); if (sum < 0) { return errors::InvalidArgument("Dimension size overflow from adding ", - first_value, " and ", second_value); + Value(first), " and ", Value(second)); } *out = MakeDim(sum); + return absl::OkStatus(); + } + + // At least one operand unknown: try to build expression. + DynExpr* lhs = ExprForDim(first); + DynExpr* rhs = + second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); + + if (lhs && rhs) { + DynExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); + } else { + *out = UnknownDim(); // Can't form expr. } return absl::OkStatus(); } @@ -1082,22 +1182,34 @@ absl::Status InferenceContext::Add(DimensionHandle first, absl::Status InferenceContext::Subtract(DimensionHandle first, DimensionOrConstant second, DimensionHandle* out) { - const int64_t first_value = Value(first); - const int64_t second_value = Value(second); - // Special cases. - if (second_value == 0) { + const bool first_known = ValueKnown(first); + const bool second_known = ValueKnown(second); + // Fast-path: x - 0 = x + if (second_known && Value(second) == 0) { *out = first; - } else if (first_value == kUnknownDim || second_value == kUnknownDim) { - *out = UnknownDim(); - } else { - // Invariant: Both values are known, first_value is non-negative, and - // second_value is positive. + return absl::OkStatus(); + } + // If both known, do numeric subtract. + if (first_known && second_known) { + const int64_t first_value = Value(first); + const int64_t second_value = Value(second); if (first_value < second_value) { return errors::InvalidArgument( "Negative dimension size caused by subtracting ", second_value, " from ", first_value); } *out = MakeDim(first_value - second_value); + return absl::OkStatus(); + } + // At least one operand unknown: try to build expression. + DynExpr* lhs = ExprForDim(first); + DynExpr* rhs = + second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); + if (lhs && rhs) { + DynExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); + } else { + *out = UnknownDim(); // Can't form expr. } return absl::OkStatus(); } @@ -1105,21 +1217,31 @@ absl::Status InferenceContext::Subtract(DimensionHandle first, absl::Status InferenceContext::Multiply(DimensionHandle first, DimensionOrConstant second, DimensionHandle* out) { + const bool first_known = ValueKnown(first); + const bool second_known = ValueKnown(second); const int64_t first_value = Value(first); const int64_t second_value = Value(second); - // Special cases. - if (first_value == 0) { + + // Fast-paths for identity and zero cases. + if (first_known && first_value == 0) { *out = first; - } else if (second_value == 0) { + return absl::OkStatus(); + } + if (second_known && second_value == 0) { *out = MakeDim(second); - } else if (first_value == 1) { + return absl::OkStatus(); + } + if (first_known && first_value == 1) { *out = MakeDim(second); - } else if (second_value == 1) { + return absl::OkStatus(); + } + if (second_known && second_value == 1) { *out = first; - } else if (first_value == kUnknownDim || second_value == kUnknownDim) { - *out = UnknownDim(); - } else { - // Invariant: Both values are known and greater than 1. + return absl::OkStatus(); + } + + // If both known, do numeric multiply. + if (first_known && second_known) { const int64_t product = MultiplyWithoutOverflow(first_value, second_value); if (product < 0) { return errors::InvalidArgument( @@ -1127,6 +1249,19 @@ absl::Status InferenceContext::Multiply(DimensionHandle first, first_value, " and ", second_value); } *out = MakeDim(product); + return absl::OkStatus(); + } + + // At least one operand unknown: try to build expression. + DynExpr* lhs = ExprForDim(first); + DynExpr* rhs = + second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); + + if (lhs && rhs) { + DynExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); + } else { + *out = UnknownDim(); // Can't form expr. } return absl::OkStatus(); } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 8bfd301d860de1..d49dbf859932ec 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" @@ -116,13 +117,16 @@ class InferenceContext; class Dimension { private: Dimension(); - Dimension(int64_t value); + Dimension(int64_t value, int64_t dynamic_ratio = 0, DynExpr* expr = nullptr); ~Dimension() {} const int64_t value_; + const int64_t dynamic_ratio_; + DynExpr* expr_; friend class InferenceContext; friend class ShapeManager; + friend class ::tensorflow::grappler::SymbolicShapeManager; Dimension(const Dimension&) = delete; void operator=(const Dimension&) = delete; }; @@ -439,6 +443,9 @@ class InferenceContext { static inline int64_t Value(DimensionOrConstant d) { return d.dim.IsSet() ? d.dim->value_ : d.val; } + static inline int64_t DynamicRatio(DimensionOrConstant d) { + return d.dim->dynamic_ratio_ ; + } static inline bool ValueKnown(DimensionOrConstant d) { return Value(d) != kUnknownDim; } @@ -572,12 +579,26 @@ class InferenceContext { // Returns a new dimension of the given size. The returned value is owned by // this context. - inline DimensionHandle MakeDim(DimensionOrConstant d) { - return shape_manager_.MakeDim(d); + inline DimensionHandle MakeDim(DimensionOrConstant d, int64_t dynamic_ratio = 0) { + return shape_manager_.MakeDim(d, dynamic_ratio); } inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } + // Create a new unknown dimension (size = -1) tagged with a DynExpr. + // The expression is owned by this context's ShapeManager. + DimensionHandle UnknownDimWithExpr(std::unique_ptr expr); + // Return the expression pointer for a dimension, or nullptr if none. + DynExpr* DimExpr(DimensionHandle d) const; + // Creates a constant DynExpr node for the given value. + // The expression is owned by this context's ShapeManager. + DynExpr* MakeConstExpr(int64_t v); + // Returns the Expr representation for the given dimension: + // - If dim has an expr, returns it + // - If dim is known, returns a new Const expr + // - If dim is unknown with no expr, returns nullptr + DynExpr* ExprForDim(DimensionHandle d); + // Returns in a scalar value from an input tensor . The input tensor // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the // input tensor is not NULL. @@ -743,8 +764,6 @@ class InferenceContext { // Adds new outputs; useful when mutating the graph. absl::Status ExpandOutputs(int new_output_size); - - private: // Creates and stores shapes for use in InferenceContext. class ShapeManager { public: @@ -760,21 +779,31 @@ class InferenceContext { // Returns a new dimension of the given size. The returned value // is owned by this class. - inline DimensionHandle MakeDim(DimensionOrConstant d) { + inline DimensionHandle MakeDim(DimensionOrConstant d, int64_t dynamic_ratio = 0, DynExpr* expr = nullptr) { if (d.dim.IsSet()) { return d.dim; } else { - all_dims_.push_back(new Dimension(d.val)); + all_dims_.push_back(new Dimension(d.val, dynamic_ratio, expr)); return all_dims_.back(); } } + // Takes ownership of an expression and returns a raw pointer to it. + DynExpr* OwnExpr(std::unique_ptr expr) { + if (!expr) return nullptr; + DynExpr* ptr = expr.get(); + all_exprs_.push_back(std::move(expr)); + return ptr; + } private: std::vector all_shapes_; // values are owned. std::vector all_dims_; // values are owned. + std::vector> all_exprs_; // expressions are owned. }; + private: friend class ::tensorflow::grappler::GraphProperties; + friend class ::tensorflow::grappler::SymbolicShapeManager; friend class ShapeInferenceTest; // For testing Relax functions. friend class ShapeInferenceTestutil; // For testing shapes. @@ -888,8 +917,8 @@ class InferenceContext { // ----------------------------------------------------------------------------- // Template and inline method implementations, please ignore -inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} -inline Dimension::Dimension(int64_t value) : value_(value) { +inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim), dynamic_ratio_(0), expr_(nullptr) {} +inline Dimension::Dimension(int64_t value, int64_t dynamic_ratio, DynExpr* expr) : value_(value), dynamic_ratio_(dynamic_ratio), expr_(expr) { DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) << "Dimension must be non-negative or equal to " "InferenceContext::kUnknownDim but got " diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 35c628216ed3c6..1146b4916ddf07 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -26,6 +26,105 @@ limitations under the License. namespace tensorflow { +xla::DynExpr* ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return xla::DynExpr::_(proto.constant_value()); + + case ExpressionProto::kVariableId: + return xla::DynExpr::V(proto.variable_id()); + + case ExpressionProto::kAddNode: { + const auto& add = proto.add_node(); + return *ExprFromProto(add.lhs()) + *ExprFromProto(add.rhs()); + } + + case ExpressionProto::kSubNode: { + const auto& sub = proto.sub_node(); + return *ExprFromProto(sub.lhs()) - *ExprFromProto(sub.rhs()); + } + + case ExpressionProto::kMulNode: { + const auto& mul = proto.mul_node(); + return *ExprFromProto(mul.lhs()) * *ExprFromProto(mul.rhs()); + } + + case ExpressionProto::kDivNode: { + const auto& div = proto.div_node(); + return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); + } + + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) { + auto e = expr->s(); + if (xla::Constant* c = dynamic_cast(e)) { + proto->set_constant_value(c->get_val()); + } else if (xla::Variable* v = dynamic_cast(e)) { + proto->set_variable_id(v->get_id()); + } else if (xla::Add* a = dynamic_cast(e)) { + auto* add_msg = proto->mutable_add_node(); + ExprToProto(a->get_lhs(), add_msg->mutable_lhs()); + ExprToProto(a->get_rhs(), add_msg->mutable_rhs()); + } else if (xla::Mul* m = dynamic_cast(e)) { + auto* mul_msg = proto->mutable_mul_node(); + ExprToProto(m->get_lhs(), mul_msg->mutable_lhs()); + ExprToProto(m->get_rhs(), mul_msg->mutable_rhs()); + } else if (xla::Sub* s = dynamic_cast(e)) { + auto* sub_msg = proto->mutable_sub_node(); + ExprToProto(s->get_lhs(), sub_msg->mutable_lhs()); + ExprToProto(s->get_rhs(), sub_msg->mutable_rhs()); + } else if (xla::Div* d = dynamic_cast(e)) { + auto* div_msg = proto->mutable_div_node(); + ExprToProto(d->get_lhs(), div_msg->mutable_lhs()); + ExprToProto(d->get_rhs(), div_msg->mutable_rhs()); + } +} + +// Independent helper function to handle the recursion +void BuildExprString(xla::DynExpr* e, std::ostringstream& oss) { + if (xla::Constant* c = dynamic_cast(e)) { + oss << c->get_val(); + } else if (xla::Variable* v = dynamic_cast(e)) { + char letter = 'A' + (v->get_id() - 1); + oss << letter; + } else if (xla::Add* a = dynamic_cast(e)) { + oss << "("; + BuildExprString(a->get_lhs(), oss); + oss << " + "; + BuildExprString(a->get_rhs(), oss); + oss << ")"; + } else if (xla::Mul* m = dynamic_cast(e)) { + oss << "("; + BuildExprString(m->get_lhs(), oss); + oss << " * "; + BuildExprString(m->get_rhs(), oss); + oss << ")"; + } else if (xla::Sub* s = dynamic_cast(e)) { + oss << "("; + BuildExprString(s->get_lhs(), oss); + oss << " - "; + BuildExprString(s->get_rhs(), oss); + oss << ")"; + } else if (xla::Div* d = dynamic_cast(e)) { + oss << "("; + BuildExprString(d->get_lhs(), oss); + oss << " / "; + BuildExprString(d->get_rhs(), oss); + oss << ")"; + } +} + +std::string ExprToString(xla::DynExpr* e) { + std::ostringstream oss; + BuildExprString(e, oss); + return oss.str(); +} + // TensorShape and PartialTensorShape should have no fields beyond // TensorShapeRep. In particular, their sizes should be the same. static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape), @@ -152,6 +251,9 @@ TensorShapeBase::TensorShapeBase(const TensorShapeProto& proto) { for (const auto& d : proto.dim()) { AddDim(d.size()); } + for (const auto& e : proto.expressions()) { + AddExpression(ExprFromProto(e)); + } } } @@ -505,6 +607,9 @@ void TensorShapeBase::UnsafeAddDim(int64_t size, template void TensorShapeBase::AppendShape(const TensorShapeBase& shape) { for (auto d : shape) AddDim(d.size); + for (auto e : shape.get_expressions()){ + AddExpression(e); + } } template @@ -585,6 +690,7 @@ template void TensorShapeBase::set_dim(int d, int64_t size) { CHECK_GE(d, 0); CHECK_LT(d, dims()); + if (get_expressions().size() > d) set_expression(d, xla::DynExpr::_(size)); if (!kIsPartial) { CHECK_GE(size, 0); } @@ -611,6 +717,8 @@ void TensorShapeBase::set_dim(int d, int64_t size) { template absl::Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { + if (get_expressions().size() > d) set_expression(d, xla::DynExpr::_(size)); + if (TF_PREDICT_FALSE(d < 0)) { return errors::InvalidArgument("Index must be non-negative, got ", d); } @@ -731,6 +839,10 @@ void TensorShapeBase::AsProto(TensorShapeProto* proto) const { for (int i = 0; i < dims(); i++) { proto->add_dim()->set_size(dim_size(i)); } + for (int i = 0; i < get_expressions().size(); i++) { + ExpressionProto* eproto = proto->add_expressions(); + ExprToProto(get_expression(i), eproto); + } } } @@ -764,6 +876,11 @@ string TensorShapeRep::DebugString() const { } else { strings::StrAppend(&s, dim); } + if (shape.get_expression(i) != nullptr) { + strings::StrAppend(&s, "("); + strings::StrAppend(&s, ExprToString(shape.get_expression(i))); + strings::StrAppend(&s, ")"); + } } strings::StrAppend(&s, "]"); return s; @@ -787,6 +904,15 @@ string TensorShapeRep::DebugString(const TensorShapeProto& proto) { first = false; } strings::StrAppend(&s, "]"); + strings::StrAppend(&s, "<"); + first = true; + for (const auto& e : proto.expressions()) { + if (!first) strings::StrAppend(&s, ","); + auto exp = ExprFromProto(e); + strings::StrAppend(&s, ExprToString(exp)); + first = false; + } + strings::StrAppend(&s, ">"); return s; } @@ -950,6 +1076,7 @@ absl::Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, return s; } } + result->set_expressions(shape.get_expressions()); return absl::OkStatus(); } diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 0bcf1fc54af844..f271714fc989d6 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "xla/shape.h" namespace tensorflow { @@ -73,7 +74,36 @@ class TensorShapeRep { std::string DebugString() const; static std::string DebugString(const TensorShapeProto& proto); + void set_expression(int d, xla::DynExpr* expr){ + expressions_[d] = expr; + } + + void AddExpression(xla::DynExpr* expr){ + expressions_.push_back(expr); + } + + // Set the array of dynamic multipliers. + void set_expressions(std::vector exprs) { + expressions_ = exprs; + } + + // Get the array of dynamic multipliers. + std::vector get_expressions() const { + return expressions_; + } + + // Return the multiplier for a specific dynamic dimension. + // -1 if the dimension is not dynamic. + xla::DynExpr* get_expression(int64_t dimension) const { + if (dimension >= expressions_.size()) { + return nullptr; + } + return expressions_[dimension]; + } + protected: + std::vector expressions_; + // Constructable only via TensorShapeBase TensorShapeRep() = default; @@ -710,6 +740,7 @@ absl::Status TensorShape::AsEigenDSizesWithPaddingWithStatus( inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { num_elements_ = b.num_elements_; + expressions_ = b.expressions_; if (b.tag() != REP_OUT_OF_LINE) { memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above Implicitly does: @@ -723,6 +754,7 @@ inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) { num_elements_ = b.num_elements_; + expressions_ = b.expressions_; memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above Implicitly does: // set_ndims_byte(b.ndims_byte()); @@ -738,6 +770,8 @@ inline TensorShapeRep::~TensorShapeRep() { inline void TensorShapeRep::operator=(const TensorShapeRep& b) { num_elements_ = b.num_elements_; + expressions_ = b.expressions_; + if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) { memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above implicitly also does: @@ -753,6 +787,8 @@ inline void TensorShapeRep::operator=(TensorShapeRep&& b) { DestructorOutOfLine(); } num_elements_ = b.num_elements_; + expressions_ = b.expressions_; + memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above Implicitly does: // set_ndims_byte(b.ndims_byte()); diff --git a/tensorflow/core/framework/tensor_shape.proto b/tensorflow/core/framework/tensor_shape.proto index 45d5b78ecbbc4c..f69b4228a7fb31 100644 --- a/tensorflow/core/framework/tensor_shape.proto +++ b/tensorflow/core/framework/tensor_shape.proto @@ -22,6 +22,10 @@ message TensorShapeProto { // Optional name of the tensor dimension. string name = 2; + //Only keep one symbolic expr. + //Symbolic expression for this dimension when size == -1. + //Allows tracking relationships between unknown dimensions. + ExpressionProto expr = 3; }; // Dimensions of the tensor, such as {"input", 30}, {"output", 40} @@ -43,4 +47,38 @@ message TensorShapeProto { // // If true, "dim.size()" must be 0. bool unknown_rank = 3; + + repeated ExpressionProto expressions = 4; + }; + +message ExpressionProto { + oneof node_type { + int32 constant_value = 1; // cons + int32 variable_id = 2; // var + AddNode add_node = 3; // exp + exp + SubNode sub_node = 4; // exp - exp + MulNode mul_node = 5; // exp * exp + DivNode div_node = 6; // exp / exp + } +} + +message AddNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message SubNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message MulNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message DivNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} \ No newline at end of file diff --git a/tensorflow/core/framework/tensor_shape_expr.cc b/tensorflow/core/framework/tensor_shape_expr.cc new file mode 100644 index 00000000000000..ef3a6b86bbcf65 --- /dev/null +++ b/tensorflow/core/framework/tensor_shape_expr.cc @@ -0,0 +1,196 @@ +#include "tensorflow/core/framework/tensor_shape_expr.h" + +namespace tensorflow { + +std::unique_ptr DynExpr::Cons(int64_t val) { + return std::make_unique(val); +} + +std::unique_ptr DynExpr::Var(int32_t id) { + return std::make_unique(id); +} + +std::string DynExpr::DebugString() const { + ExpressionProto proto; + ToProto(&proto); + return proto.DebugString(); +} + +static bool EqualsImpl(const DynExpr* a, const DynExpr* b) { + if (a == b) return true; + if (a == nullptr || b == nullptr) return false; + if (a->kind() != b->kind()) return false; + + switch (a->kind()) { + case DynExpr::Kind::kConstant: { + auto* ac = static_cast(a); + auto* bc = static_cast(b); + return ac->value() == bc->value(); + } + case DynExpr::Kind::kVariable: { + auto* av = static_cast(a); + auto* bv = static_cast(b); + return av->id() == bv->id(); + } + case DynExpr::Kind::kAdd: { + auto* aa = static_cast(a); + auto* ba = static_cast(b); + return EqualsImpl(aa->lhs(), ba->lhs()) && + EqualsImpl(aa->rhs(), ba->rhs()); + } + case DynExpr::Kind::kSub: { + auto* as = static_cast(a); + auto* bs = static_cast(b); + return EqualsImpl(as->lhs(), bs->lhs()) && + EqualsImpl(as->rhs(), bs->rhs()); + } + case DynExpr::Kind::kMul: { + auto* am = static_cast(a); + auto* bm = static_cast(b); + return EqualsImpl(am->lhs(), bm->lhs()) && + EqualsImpl(am->rhs(), bm->rhs()); + } + case DynExpr::Kind::kDiv: { + auto* ad = static_cast(a); + auto* bd = static_cast(b); + return EqualsImpl(ad->lhs(), bd->lhs()) && + EqualsImpl(ad->rhs(), bd->rhs()); + } + } + + return false; +} + +bool DynExpr::Equals(const DynExpr* a, const DynExpr* b) { + return EqualsImpl(a, b); +} + +std::unique_ptr DynExpr::FromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DynExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DynExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = FromProto(proto.add_node().lhs()); + auto rhs = FromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = FromProto(proto.sub_node().lhs()); + auto rhs = FromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = FromProto(proto.mul_node().lhs()); + auto rhs = FromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = FromProto(proto.div_node().lhs()); + auto rhs = FromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +DynExpr* SimplifyExpr(DynExpr* expr, + std::vector>* arena) { + if (!expr) return nullptr; + + auto own = [arena](std::unique_ptr e) -> DynExpr* { + DynExpr* ptr = e.get(); + arena->push_back(std::move(e)); + return ptr; + }; + + switch (expr->kind()) { + case DynExpr::Kind::kConstant: + case DynExpr::Kind::kVariable: + return expr; + + case DynExpr::Kind::kAdd: { + auto* add = static_cast(expr); + DynExpr* lhs = SimplifyExpr(add->lhs(), arena); + DynExpr* rhs = SimplifyExpr(add->rhs(), arena); + + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + return own(DynExpr::Cons(lhs->ConstantValue() + rhs->ConstantValue())); + } + + // x + 0 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 0) return lhs; + if (lhs->IsConstant() && lhs->ConstantValue() == 0) return rhs; + + return own(std::make_unique(lhs, rhs)); + } + + case DynExpr::Kind::kSub: { + auto* sub = static_cast(expr); + DynExpr* lhs = SimplifyExpr(sub->lhs(), arena); + DynExpr* rhs = SimplifyExpr(sub->rhs(), arena); + + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + return own(DynExpr::Cons(lhs->ConstantValue() - rhs->ConstantValue())); + } + + // x - 0 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 0) return lhs; + + return own(std::make_unique(lhs, rhs)); + } + + case DynExpr::Kind::kMul: { + auto* mul = static_cast(expr); + DynExpr* lhs = SimplifyExpr(mul->lhs(), arena); + DynExpr* rhs = SimplifyExpr(mul->rhs(), arena); + + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + return own(DynExpr::Cons(lhs->ConstantValue() * rhs->ConstantValue())); + } + + // x * 1 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 1) return lhs; + if (lhs->IsConstant() && lhs->ConstantValue() == 1) return rhs; + + // x * 0 → 0 + if (rhs->IsConstant() && rhs->ConstantValue() == 0) + return own(DynExpr::Cons(0)); + if (lhs->IsConstant() && lhs->ConstantValue() == 0) + return own(DynExpr::Cons(0)); + + return own(std::make_unique(lhs, rhs)); + } + + case DynExpr::Kind::kDiv: { + auto* div = static_cast(expr); + DynExpr* lhs = SimplifyExpr(div->lhs(), arena); + DynExpr* rhs = SimplifyExpr(div->rhs(), arena); + + // Constant folding (avoid div by zero) + if (lhs->IsConstant() && rhs->IsConstant()) { + int64_t r = rhs->ConstantValue(); + if (r != 0) { + return own(DynExpr::Cons(lhs->ConstantValue() / r)); + } + } + + // x / 1 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 1) return lhs; + + return own(std::make_unique(lhs, rhs)); + } + } + + return expr; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_shape_expr.h b/tensorflow/core/framework/tensor_shape_expr.h new file mode 100644 index 00000000000000..04974157af854b --- /dev/null +++ b/tensorflow/core/framework/tensor_shape_expr.h @@ -0,0 +1,219 @@ +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace tensorflow { + +// Forward declarations +class Constant; +class Variable; +class ExprAdd; +class ExprSub; +class ExprMul; +class ExprDiv; + +// DynExpr: Base class for symbolic expressions representing dynamic dimension +// sizes. These expressions form a DAG that tracks how unknown dimensions relate +// to each other through arithmetic operations. +// +// The expression language: +// - Var(sym_id): A symbolic variable representing an unknown dimension +// - Const(k): A known constant value +// - Add/Sub/Mul/Div(lhs, rhs): Binary arithmetic operations +// +// INVARIANT: An unknown dimension is not just -1, it is -1 + Var(sym). +class DynExpr { + public: + enum class Kind : uint8_t { + kConstant, + kVariable, + kAdd, + kSub, + kMul, + kDiv, + }; + + virtual ~DynExpr() = default; + + virtual Kind kind() const = 0; + virtual void ToProto(ExpressionProto* proto) const = 0; + + virtual bool IsConstant() const { return false; } + virtual int64_t ConstantValue() const { return 0; } + + // Factory methods - return owning pointers + static std::unique_ptr Cons(int64_t val); + static std::unique_ptr Var(int32_t var_id); + + // Structural equality check + static bool Equals(const DynExpr* a, const DynExpr* b); + + // Build from proto (owns all returned nodes) + static std::unique_ptr FromProto(const ExpressionProto& proto); + + // Debug representation + std::string DebugString() const; + + protected: + DynExpr() = default; +}; + +// Constant expression node: represents a known integer value +class Constant final : public DynExpr { + public: + explicit Constant(int64_t value) : value_(value) {} + + Kind kind() const override { return Kind::kConstant; } + void ToProto(ExpressionProto* proto) const override { + proto->set_constant_value(value_); + } + + bool IsConstant() const override { return true; } + int64_t ConstantValue() const override { return value_; } + + int64_t value() const { return value_; } + + private: + int64_t value_; +}; + +// Variable expression node: represents a symbolic unknown dimension +class Variable final : public DynExpr { + public: + explicit Variable(int32_t id) : id_(id) {} + + Kind kind() const override { return Kind::kVariable; } + void ToProto(ExpressionProto* proto) const override { + proto->set_variable_id(id_); + } + + int32_t id() const { return id_; } + + private: + int32_t id_; +}; + +// Addition expression node +class ExprAdd final : public DynExpr { + public: + ExprAdd(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kAdd; } + void ToProto(ExpressionProto* proto) const override { + auto* add_msg = proto->mutable_add_node(); + lhs_->ToProto(add_msg->mutable_lhs()); + rhs_->ToProto(add_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + return lhs_->ConstantValue() + rhs_->ConstantValue(); + } + + DynExpr* lhs() const { return lhs_; } + DynExpr* rhs() const { return rhs_; } + + private: + DynExpr* lhs_; + DynExpr* rhs_; +}; + +// Subtraction expression node +class ExprSub final : public DynExpr { + public: + ExprSub(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kSub; } + void ToProto(ExpressionProto* proto) const override { + auto* sub_msg = proto->mutable_sub_node(); + lhs_->ToProto(sub_msg->mutable_lhs()); + rhs_->ToProto(sub_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + return lhs_->ConstantValue() - rhs_->ConstantValue(); + } + + DynExpr* lhs() const { return lhs_; } + DynExpr* rhs() const { return rhs_; } + + private: + DynExpr* lhs_; + DynExpr* rhs_; +}; + +// Multiplication expression node +class ExprMul final : public DynExpr { + public: + ExprMul(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kMul; } + void ToProto(ExpressionProto* proto) const override { + auto* mul_msg = proto->mutable_mul_node(); + lhs_->ToProto(mul_msg->mutable_lhs()); + rhs_->ToProto(mul_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + return lhs_->ConstantValue() * rhs_->ConstantValue(); + } + + DynExpr* lhs() const { return lhs_; } + DynExpr* rhs() const { return rhs_; } + + private: + DynExpr* lhs_; + DynExpr* rhs_; +}; + +// Division expression node +class ExprDiv final : public DynExpr { + public: + ExprDiv(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kDiv; } + void ToProto(ExpressionProto* proto) const override { + auto* div_msg = proto->mutable_div_node(); + lhs_->ToProto(div_msg->mutable_lhs()); + rhs_->ToProto(div_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + int64_t r = rhs_->ConstantValue(); + return (r == 0) ? 0 : lhs_->ConstantValue() / r; + } + + DynExpr* lhs() const { return lhs_; } + DynExpr* rhs() const { return rhs_; } + + private: + DynExpr* lhs_; + DynExpr* rhs_; +}; + +// Simplify an expression tree: constant folding and algebraic identities. +// Returns a NEW expression (does not mutate input). +// The arena parameter is used to allocate nodes that will be owned externally. +DynExpr* SimplifyExpr(DynExpr* expr, + std::vector>* arena); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 613b12bb18ae3a..cecf7459ab5f4a 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/mutable_graph_view.h" @@ -189,6 +190,9 @@ class DisjointSet { absl::Status Merge(Handle x, Handle y); const typename HandleToObject::Object GetMergedValue(Handle value); + // Returns a pointer that uniquely identifies the set containing `value`. + // This can be used as a stable key for associating metadata with a set. + void* RootId(Handle value) { return static_cast(Find(value)); } private: // All the handles that belong to the same set are part of the same tree, and @@ -1148,32 +1152,32 @@ class SymbolicShapeRefiner { } struct ShapeId { - const NodeDef* node; + std::string node_name; int port_id; friend bool operator==(const ShapeId& lhs, const ShapeId& rhs) { - return lhs.node == rhs.node && lhs.port_id == rhs.port_id; + return lhs.node_name == rhs.node_name && lhs.port_id == rhs.port_id; } template friend H AbslHashValue(H h, const ShapeId& s) { - return H::combine(std::move(h), s.node, s.port_id); + return H::combine(std::move(h), s.node_name, s.port_id); } }; struct DimId { - const NodeDef* node; + std::string node_name; int port_id; int dim_index; friend bool operator==(const DimId& lhs, const DimId& rhs) { - return lhs.node == rhs.node && lhs.port_id == rhs.port_id && + return lhs.node_name == rhs.node_name && lhs.port_id == rhs.port_id && lhs.dim_index == rhs.dim_index; } template friend H AbslHashValue(H h, const DimId& d) { - return H::combine(std::move(h), d.node, d.port_id, d.dim_index); + return H::combine(std::move(h), d.node_name, d.port_id, d.dim_index); } }; @@ -1391,7 +1395,7 @@ class SymbolicShapeRefiner { // Return the one ShapeHandle used to denote a fully unknown shape for a node // output. ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) { - ShapeId id{node, index}; + ShapeId id{node->name(), index}; auto it = unknown_shapes_.find(id); if (it != unknown_shapes_.end()) { return it->second; @@ -1405,16 +1409,38 @@ class SymbolicShapeRefiner { // node output. DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index, int dim_id) { - DimId id{node, index, dim_id}; + DimId id{node->name(), index, dim_id}; auto it = unknown_dims_.find(id); if (it != unknown_dims_.end()) { return it->second; } InferenceContext* c = GetContext(node); - DimensionHandle dim = c->UnknownDim(); + int var_id = GetOrCreateStableVarId(id); + DimensionHandle dim; + if (node->op() == "_Arg") { + var_id *= -1; + // var_id would be minus when it's argument. + dim = c->UnknownDimWithExpr(DynExpr::Var(var_id)); + } else { + dim = c->UnknownDimWithExpr(DynExpr::Var(var_id)); + } + VLOG(1) << "[EXPR] GetUnknownOutputDim: node=" << node->name() + << " out=" << index << " dim=" << dim_id << " -> Var(" << var_id + << ")"; + // Create an unknown dim with Var(var_id) expression. unknown_dims_[id] = dim; return dim; } + // Get or create a stable integer variable ID for a given DimId. + int GetOrCreateStableVarId(const DimId& id) { + auto it = stable_var_ids_.find(id); + if (it != stable_var_ids_.end()) { + return it->second; + } + int var_id = next_var_id_++; + stable_var_ids_[id] = var_id; + return var_id; + } // Returns true if all the output tensors have known values. bool AllOutputValuesKnown(NodeContext* c) { @@ -1712,7 +1738,50 @@ class SymbolicShapeRefiner { // but instantiate a new UnknownDim to prevent incorrect symbolic shape // inference through UnknownDim from Const. InferenceContext* ic = c->inference_context.get(); + const std::string& op = node.op(); + const bool is_bin = + (op == "Sub" || op == "Add" || op == "Mul" || op == "Div"); if (!is_fed) { + if (is_bin) { + if (c->input_tensors_as_shapes_to_propagate.size() < 2) + return absl::OkStatus(); + auto va = c->input_tensors_as_shapes_to_propagate[0]; + auto vb = c->input_tensors_as_shapes_to_propagate[1]; + + if (va.SameHandle(tensorflow::shape_inference::ShapeHandle()) || + vb.SameHandle(tensorflow::shape_inference::ShapeHandle())) { + return absl::OkStatus(); + } + + if (!ic->RankKnown(va) || !ic->RankKnown(vb)) return absl::OkStatus(); + if (ic->Rank(va) != ic->Rank(vb)) return absl::OkStatus(); + + std::vector out_elems; + out_elems.reserve(ic->Rank(va)); + + for (int i = 0; i < ic->Rank(va); ++i) { + auto da = ic->Dim(va, i); + auto db = ic->Dim(vb, i); + + tensorflow::shape_inference::DimensionHandle r; + if (op == "Sub") + TF_RETURN_IF_ERROR(ic->Subtract(da, db, &r)); + else if (op == "Add") + TF_RETURN_IF_ERROR(ic->Add(da, db, &r)); + else if (op == "Mul") + TF_RETURN_IF_ERROR(ic->Multiply(da, db, &r)); + else + TF_RETURN_IF_ERROR( + ic->Divide(da, db, /*evenly_divisible=*/false, &r)); + out_elems.push_back(r); + } + c->output_tensors_as_shapes.resize(1); + c->output_tensors_as_shapes[0] = ic->MakeShape(out_elems); + // @TODO: Check if we need to do anything with output_tensor_protos. + // S.t c->output_tensor_protos[0] = nullptr; + return absl::OkStatus(); + } + if (IsConstant(node)) { const TensorProto& tensor_proto = node.attr().at("value").tensor(); c->output_tensor_protos.resize(1); @@ -1759,7 +1828,7 @@ class SymbolicShapeRefiner { for (int i = 0; i < c->inference_context->num_inputs(); ++i) { c->output_tensors_as_shapes[i] = c->inference_context->input(i); } - } else if (node.op() == "ConcatV2") { + } else if (op == "ConcatV2") { bool valid = true; ShapeHandle result; for (int i = 0; i < ic->num_inputs() - 1; ++i) { @@ -1919,6 +1988,75 @@ class SymbolicShapeRefiner { return absl::OkStatus(); } + absl::Status CanonicalizeOutputDims(const NodeDef* node) { + NodeContext* ctx = GetNodeContext(node); + if (!ctx) return absl::OkStatus(); + + InferenceContext* ic = ctx->inference_context.get(); + for (int out = 0; out < ic->num_outputs(); ++out) { + ShapeHandle s = ic->output(out); + + if (!ic->RankKnown(s) && node->op() == "_Arg") { + // Treat batched function arguments as vectors with batch dim at dim0. + DimensionHandle d0 = GetUnknownOutputDim(node, out, /*dim_index=*/0); + ShapeHandle vec = ic->MakeShape({d0}); + ic->set_output(out, vec); + s = vec; + } + + if (!ic->RankKnown(s)){ + //if Rank is not realized yet, get it from attr. + auto it = node->attr().find("_output_shapes"); + if (it != node->attr().end() && it->second.list().shape_size()>0){ + const TensorShapeProto& proto = it->second.list().shape(out); + + std::vector dims; + dims.reserve(proto.dim_size()); + + for(int d=0; d=0){ + dims.push_back(ic->MakeDim(size)); + } else { + dims.push_back(GetUnknownOutputDim(node, out, d)); + } + } + s = ic->MakeShape(dims); + ic->set_output(out,s); + }else { + VLOG(1) << "RANK still unknown." << node->name(); + continue; + } + } + bool changed = false; + std::vector dims; + dims.reserve(ic->Rank(s)); + for (int d = 0; d < ic->Rank(s); ++d) { + DimensionHandle dim = ic->Dim(s, d); + const int64_t v = ic->Value(dim); + // Keep concrete dims. + if (v >= 0) { + dims.push_back(dim); + continue; + } + // If already tagged with expr, keep it. + if (ic->DimExpr(dim) != nullptr) { + dims.push_back(dim); + continue; + } + // Canonicalize ALL unknown dims. + DimensionHandle canon = GetUnknownOutputDim(node, out, d); + changed |= !dim.SameHandle(canon); + dims.push_back(canon); + } + if (changed) { + ShapeHandle new_s = ic->MakeShape(dims); + ic->set_output(out, new_s); + } + } + return absl::OkStatus(); + } + absl::Status InferShapes(const NodeDef& node, NodeContext* c) { // Infer the shapes of output tensors. if (!c->op_data || c->op_data->shape_inference_fn == nullptr || @@ -1940,8 +2078,8 @@ class SymbolicShapeRefiner { status.Update(SetUnknownShape(&node, output_port)); } } - // Update NodeContext output fields after shape inference function runs. + status.Update(CanonicalizeOutputDims(&node)); status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c)); return status; @@ -2048,6 +2186,9 @@ class SymbolicShapeRefiner { absl::flat_hash_map node_to_context_; absl::flat_hash_map unknown_shapes_; absl::flat_hash_map unknown_dims_; + // Stable variable IDs for canonical dimension symbols. + absl::flat_hash_map stable_var_ids_; + int next_var_id_ = 1; // Store function instantiations only for valid function. If function // instantiation failed it will have an `absl::nullopt`. absl::flat_hash_map> @@ -2080,8 +2221,9 @@ class SymbolicShapeManager { if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) { CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2)); for (int i = 0; i < InferenceContext::Rank(s1); ++i) { - TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i), - InferenceContext::DimKnownRank(s2, i))); + TF_RETURN_IF_ERROR( + MergeDimsWithExpr(InferenceContext::DimKnownRank(s1, i), + InferenceContext::DimKnownRank(s2, i))); } } return absl::OkStatus(); @@ -2090,7 +2232,7 @@ class SymbolicShapeManager { if (!d1.IsSet() || !d2.IsSet()) { return absl::OkStatus(); } - return dims_.Merge(d1, d2); + return MergeDimsWithExpr(d1, d2); } void AsTensorProperties(const ShapeHandle& shape, const DataType& type, @@ -2104,7 +2246,23 @@ class SymbolicShapeManager { shape_inference::DimensionHandle dim = InferenceContext::DimKnownRank(actual_shape, j); int64_t d = dims_.GetMergedValue(dim); - properties->mutable_shape()->add_dim()->set_size(d); + auto* out_dim = properties->mutable_shape()->add_dim(); + out_dim->set_size(d < 0 ? -1 : d); + // Serialize expression for unknown dims + if (out_dim->size() < 0) { + void* root = dims_.RootId(dim); + DynExpr* expr = nullptr; + if (auto it = dim_root_expr_.find(root); it != dim_root_expr_.end()) { + expr = it->second; + } else { + // Fallback to checking the dim handle directly + expr = GetExprFromDimHandle(dim); + } + if (expr != nullptr) { + expr->ToProto(out_dim->mutable_expr()); + // TODO: Apply simplification? + } + } } } } @@ -2132,7 +2290,119 @@ class SymbolicShapeManager { } private: + // Get the variable ID from an expression, or -1 if not a variable. + static int32_t GetVarId(const DynExpr* e) { + if (!e || e->kind() != DynExpr::Kind::kVariable) return -1; + return static_cast(e)->id(); + } + + static bool IsConst(const DynExpr* e) { + return e && e->kind() == DynExpr::Kind::kConstant; + } + + static bool IsVar(const DynExpr* e) { + return e && e->kind() == DynExpr::Kind::kVariable; + } + + static bool IsPlaceHolder(const DynExpr* e) { + if (!e) return false; + if (e->kind() != DynExpr::Kind::kVariable) return false; + return static_cast(e)->id() < 0; + } + + static bool IsCompound(const DynExpr* e) { + if (!e) return false; + switch (e->kind()) { + case DynExpr::Kind::kAdd: + case DynExpr::Kind::kSub: + case DynExpr::Kind::kMul: + case DynExpr::Kind::kDiv: + return true; + default: + return false; + } + } + + // Ranking: Const > Arg_ > Compound > Var > null + static int InfoScore(const DynExpr* e) { + if (!e) return 0; + if (IsConst(e)) return 4; + if (IsPlaceHolder(e)) return 3; + if (IsCompound(e)) return 2; + if (IsVar(e)) return 1; + return 1; // fallback (shouldn't happen) + } + + // Prefer "more informative" but keep deterministic tie-break. + static DynExpr* PreferMoreInformative(DynExpr* a, DynExpr* b) { + if (a == b) return a; + const int sa = InfoScore(a); + const int sb = InfoScore(b); + if (sa > sb) return a; + if (sb > sa) return b; + // Same score: keep stable choice. + return a; + } + + // Get the expr pointer from a dimension handle (accesses private member). + static DynExpr* GetExprFromDimHandle(const DimensionHandle& d) { + if (!d.IsSet()) return nullptr; + return d->expr_; + } + + absl::Status MergeDimsWithExpr(DimensionHandle d1, DimensionHandle d2) { + if (!d1.IsSet() || !d2.IsSet()) return absl::OkStatus(); + + void* r1 = dims_.RootId(d1); + void* r2 = dims_.RootId(d2); + + // Fetch best-known expr for each set. + auto get_best = [&](void* r, DimensionHandle d) -> DynExpr* { + auto it = dim_root_expr_.find(r); + if (it != dim_root_expr_.end()) return it->second; + return GetExprFromDimHandle(d); // may be null + }; + + DynExpr* e1 = get_best(r1, d1); + DynExpr* e2 = get_best(r2, d2); + + // If already in same UF set, just keep the most informative expr. + if (r1 == r2) { + DynExpr* existing = nullptr; + if (auto it = dim_root_expr_.find(r1); it != dim_root_expr_.end()) { + existing = it->second; + } + DynExpr* chosen = PreferMoreInformative(existing, + PreferMoreInformative(e1, e2)); + if (chosen) dim_root_expr_[r1] = chosen; // keep or upgrade + return absl::OkStatus(); + } + + // Perform UF merge (rank + path compression inside DisjointSet). + TF_RETURN_IF_ERROR(dims_.Merge(d1, d2)); + + // New root after merge. + void* new_root = dims_.RootId(d1); + + // Choose best expr across both sets. + DynExpr* chosen = PreferMoreInformative(e1, e2); + + // Remove stale root keys (only the old roots). + dim_root_expr_.erase(r1); + dim_root_expr_.erase(r2); + + // Preserve any expr already stored at new_root (rare but safe). + if (auto it = dim_root_expr_.find(new_root); it != dim_root_expr_.end()) { + chosen = PreferMoreInformative(it->second, chosen); + } + + if (chosen) dim_root_expr_[new_root] = chosen; + + return absl::OkStatus(); + } DisjointSet shapes_; + // Map from union-find root pointer to the best expression for that set. + absl::flat_hash_map dim_root_expr_; DisjointSet dims_; }; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index 3b50099fb9997c..5cf8e4a6cfb68f 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -401,6 +401,10 @@ std::vector PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero( const PartialTensorShape& partial = partial_shapes[i]; TensorShape& shape = shapes[i]; for (int64_t s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s); + for (auto e : partial.get_expressions()){ + shape.AddExpression( + e->is_constant() && e->get_val() < 0 ? xla::DynExpr::zero : e); + } } return shapes; } diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 4523126f9a71bd..8dc2b49e87e09a 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -309,6 +309,8 @@ class StridedSliceAssignOp : public OpKernel { bool is_simple_slice = true; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; Tensor* old_lhs = nullptr; @@ -353,7 +355,7 @@ class StridedSliceAssignOp : public OpKernel { old_lhs->shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape, &is_identity, &is_simple_slice, &slice_dim0, - &begin, &end, &strides, &shape_spec)); + &begin, &end, &strides, &begin_expr, &end_expr, &shape_spec)); if (processing_shape.num_elements() > 0) { const Tensor& input = context->input(4); diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 8d53c6dbb38425..a9e19b6d7f5935 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -210,6 +210,11 @@ absl::Status SetOutputShapeForReshape(InferenceContext* c) { for (int32_t i = 0; i < c->Rank(out); ++i) { DimensionHandle dim = c->Dim(out, i); if (!c->ValueKnown(dim)) { + if (c->DimExpr(dim) != nullptr) { + TF_RETURN_IF_ERROR( + c->Multiply(known_out_elems, dim, &known_out_elems)); + continue; + } if (out_unknown_idx >= 0) { too_many_unknown = true; break; @@ -228,6 +233,11 @@ absl::Status SetOutputShapeForReshape(InferenceContext* c) { for (int32_t i = 0; i < c->Rank(in); ++i) { DimensionHandle dim = c->Dim(in, i); if (!c->ValueKnown(dim)) { + if (c->DimExpr(dim) != nullptr) { + TF_RETURN_IF_ERROR( + c->Multiply(known_in_elems, dim, &known_in_elems)); + continue; + } if (in_unknown_idx >= 0) { too_many_unknown = true; break; diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 93c5a7e9818ae2..26712412a54b1e 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/lib/core/status.h" +#include "xla/shape.h" + namespace tensorflow { namespace { @@ -55,6 +57,8 @@ struct StridedSliceDenseSpec { bool end_valid; absl::InlinedVector& begin; absl::InlinedVector& end; + absl::InlinedVector& begin_expr; + absl::InlinedVector& end_expr; absl::InlinedVector& strides; // This vector helps construct the final shape of the slice. // The final tensor is reduced in rank whenever a single index e.g. foo[3] @@ -97,6 +101,8 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, // to remove any ellipsis dense->begin.resize(dense->dims); dense->end.resize(dense->dims); + dense->begin_expr.resize(dense->dims); + dense->end_expr.resize(dense->dims); dense->strides.resize(dense->dims); dense->input_shape_gather_indices_sparse.resize(dense->dims); // What indices to get the final shape from. @@ -127,6 +133,8 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, for (; full_index < next_index; full_index++) { // new_axis' aren't real axis so you have to skip dense->begin[full_index] = dense->end[full_index] = 0; + dense->begin_expr[full_index] = dense->end_expr[full_index] = + xla::DynExpr::zero; dense->strides[full_index] = 1; dense->begin_mask |= (1 << full_index); dense->end_mask |= (1 << full_index); @@ -150,9 +158,13 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, // Gather slicing spec into appropriate index if (begin_flat != nullptr) { dense->begin[full_index] = internal::SubtleMustCopy(begin_flat[i]); + dense->begin_expr[full_index] = + xla::DynExpr::_(dense->begin[full_index]); } if (end_flat != nullptr) { dense->end[full_index] = internal::SubtleMustCopy(end_flat[i]); + dense->end_expr[full_index] = + xla::DynExpr::_(dense->end[full_index]); } dense->strides[full_index] = internal::SubtleMustCopy(strides_flat[i]); @@ -193,10 +205,29 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr, + absl::InlinedVector* end_expr, StridedSliceShapeSpec* shape_spec) { + absl::InlinedVector b; + absl::InlinedVector e; + + // HACK + if (begin_expr == nullptr) { + for (int i : *begin) { + b.push_back(xla::DynExpr::_(i)); + } + begin_expr = &b; + } + if (end_expr == nullptr) { + for (int i : *end) { + e.push_back(xla::DynExpr::_(i)); + } + end_expr = &e; + } + if (input_shape.unknown_rank()) { // Note: If the rank is unknown, "input_shape.dims()" is -1. - return errors::InvalidArgument("Unexpected input_shape with unknown rank"); + return errors::InvalidArgument("Unexpected input_shape with unknown rank"); } const bool begin_is_wrong = @@ -271,6 +302,7 @@ absl::Status ValidateStridedSliceOp( // we need to produce the missing begin_mask for the first two // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 // we achieve begin_mask=6, end_mask=7 + StridedSliceDenseSpec dense_spec = {input_shape.dims(), 0 /* begin_mask */, 0 /* end_mask */, @@ -278,6 +310,8 @@ absl::Status ValidateStridedSliceOp( false /* end_valid */, *begin, *end, + *begin_expr, + *end_expr, *strides}; if (strides_tensor.dtype() == DT_INT32) { @@ -301,6 +335,11 @@ absl::Status ValidateStridedSliceOp( int64_t& end_i = (*end)[i]; int64_t& stride_i = (*strides)[i]; int64_t dim_i = input_shape.dim_size(i); + auto dim_exprs = input_shape.get_expressions(); + + xla::DynExpr* dim_i_expr = + i < dim_exprs.size() ? dim_exprs[i] : xla::DynExpr::_(dim_i); + if (stride_i == 0) { return errors::InvalidArgument("strides[", i, "] must be non-zero"); } @@ -315,15 +354,36 @@ absl::Status ValidateStridedSliceOp( const std::array valid_range = { {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}}; + const std::array valid_range_expr = { + {stride_i > 0 ? xla::DynExpr::zero : xla::DynExpr::_(-1), + stride_i > 0 ? dim_i_expr : (*dim_i_expr - *xla::DynExpr::one)->s()}}; + auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) { if (masks[c]) { return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; } else { int64_t x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive - return x_fwd < valid_range[0] - ? valid_range[0] - : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; + return x_fwd < valid_range[0] ? valid_range[0] + : x_fwd > valid_range[1] ? valid_range[1] + : x_fwd; + } + }; + auto canonical_expr = [stride_i, dim_i, masks, valid_range, + valid_range_expr, dim_i_expr](int64_t x, int c) { + if (masks[c]) { + return stride_i > 0 ? valid_range_expr[c] + : valid_range_expr[(c + 1) & 1]; + } else { + int64_t x_fwd = + x < 0 ? dim_i + x : x; // make negative indices positive + xla::DynExpr* x_expr = xla::DynExpr::_(x); + xla::DynExpr* x_fwd_expr = + x < 0 ? (*dim_i_expr + *x_expr) + : x_expr; // make negative indices positive + return x_fwd < valid_range[0] ? valid_range_expr[0] + : x_fwd > valid_range[1] ? valid_range_expr[1] + : x_fwd_expr; } }; if (shrink_i && stride_i <= 0) { @@ -341,8 +401,15 @@ absl::Status ValidateStridedSliceOp( // and canonical puts these to n-1 and 0, which implies a degenerate // interval. Fortunately, it is now safe to re-create end as begin+1. int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; + xla::DynExpr* x_fwd_expr = begin_i < 0 + ? (*dim_i_expr + *(*begin_expr)[i])->s() + : (*begin_expr)[i]; begin_i = x_fwd; end_i = begin_i + 1; + + (*begin_expr)[i] = x_fwd_expr; + (*end_expr)[i] = (*(*begin_expr)[i] + *xla::DynExpr::one)->s(); + if (x_fwd < 0 || x_fwd >= dim_i) { return errors::InvalidArgument( "slice index ", begin_i, " of dimension ", i, " out of bounds."); @@ -350,6 +417,12 @@ absl::Status ValidateStridedSliceOp( } else { begin_i = canonical(begin_i, 0); end_i = canonical(end_i, 1); + if (begin_expr) { + (*begin_expr)[i] = canonical_expr(begin_i, 0)->s(); + } + if (end_expr) { + (*end_expr)[i] = canonical_expr(end_i, 1)->s(); + } } // Update optimization values bool take_all_in_dimension = @@ -362,14 +435,17 @@ absl::Status ValidateStridedSliceOp( } // Compute the processing shape (the intermediate Eigen will produce) int64_t interval_length; + xla::DynExpr* interval_length_expr; bool known_interval = false; if (dense_spec.begin_valid && dense_spec.end_valid) { interval_length = end_i - begin_i; + interval_length_expr = (*(*end_expr)[i] - *(*begin_expr)[i])->s(); known_interval = true; } else if (shrink_i) { // The dimension is still known as 1 for the processing_shape, but will be // discarded for the final shape. interval_length = 1; + interval_length_expr = xla::DynExpr::one; known_interval = true; } else if (begin_and_end_masked) { // Even if we don't have values for begin or end, we do know that this @@ -378,25 +454,34 @@ absl::Status ValidateStridedSliceOp( if (dim_i >= 0) { if (stride_i < 0) { interval_length = -dim_i; + interval_length_expr = (-1 * (*dim_i_expr))->s(); } else { interval_length = dim_i; + interval_length_expr = dim_i_expr; } known_interval = true; } } if (known_interval) { int64_t size_i; + xla::DynExpr* size_i_expr; // Hold zero if the interval is degenerate, otherwise account for // remainder if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) { size_i = 0; + size_i_expr = xla::DynExpr::zero; } else { size_i = interval_length / stride_i + (interval_length % stride_i != 0 ? 1 : 0); + size_i_expr = *(*interval_length_expr / stride_i) + + *(interval_length % stride_i != 0 ? xla::DynExpr::one + : xla::DynExpr::zero); } processing_shape->AddDim(size_i); + processing_shape->AddExpression(size_i_expr->s()); } else { processing_shape->AddDim(-1); + processing_shape->AddExpression(xla::DynExpr::_(-1)); } } @@ -425,12 +510,15 @@ absl::Status ValidateStridedSliceOp( dense_spec.final_shape_gather_indices_sparse[dense_dim]; if (gather_index >= 0) { final_shape->AddDim(processing_shape->dim_size(gather_index)); + final_shape->AddExpression( + processing_shape->get_expression(gather_index)); if (shape_spec != nullptr) { shape_spec->output_to_sparse_mapping.push_back(sparse_index); shape_spec->output_to_processing_mapping.push_back(gather_index); } } else if (gather_index == kNewAxis) { final_shape->AddDim(1); + final_shape->AddExpression(xla::DynExpr::one); if (shape_spec != nullptr) { shape_spec->output_to_sparse_mapping.push_back(-1); shape_spec->output_to_processing_mapping.push_back(-1); @@ -451,6 +539,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr, + absl::InlinedVector* end_expr, StridedSliceShapeSpec* shape_spec) { // Validate with PartialTensorShape output PartialTensorShape partial_processing_shape, partial_final_shape; @@ -458,7 +548,8 @@ absl::Status ValidateStridedSliceOp( begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask, &partial_processing_shape, &partial_final_shape, is_identity, - is_simple_slice, slice_dim0, begin, end, strides, shape_spec)); + is_simple_slice, slice_dim0, begin, end, strides, begin_expr, end_expr, + shape_spec)); // Verify that the output shapes are fully known if (!partial_processing_shape.AsTensorShape(processing_shape) || diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h index 356b77a2a0f5b6..3e55d9bb384481 100644 --- a/tensorflow/core/util/strided_slice_op.h +++ b/tensorflow/core/util/strided_slice_op.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "xla/shape.h" namespace tensorflow { @@ -74,6 +75,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr = nullptr, + absl::InlinedVector* end_expr = nullptr, StridedSliceShapeSpec* shape_spec = nullptr); // Same as above, but the outputs are TensorShape, not PartialTensorShape @@ -87,6 +90,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr = nullptr, + absl::InlinedVector* end_expr = nullptr, StridedSliceShapeSpec* shape_spec = nullptr); // Simple class for determining if it is possible to broadcast a tensor to a diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc index 30f64818209cd4..5c5188ce33b5ce 100644 --- a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc @@ -533,7 +533,7 @@ llvm_ir::IrArray KernelApiIrBuilder::EmitKernelArgument( const llvm::DataLayout& data_layout = llvm_module->getDataLayout(); int64_t pointer_size = data_layout.getTypeStoreSize(builder.getPtrTy()); int64_t byte_size = ShapeUtil::ByteSizeOf(shape, pointer_size); - if (shape.outer_multiplier() < 0) + if (!shape.has_dynamic_expr()) llvm_ir::SetDereferenceableMetadataForLoad(data,byte_size); // All buffers pointers passed to host kernels are expected to be invariant diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc index 5b0f223302f4ed..9cb5fa5eefa2b8 100644 --- a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc @@ -131,10 +131,10 @@ XlaOp AggregateToTopKBuilder(XlaBuilder* builder, reduction_computation, {reduction_dim}); Shape op_shape = operands_shapes[0]; op_shape.set_dimensions(reduction_dim, 1); - auto top1_vals = - Reshape(GetTupleElement(val_args, 0), op_shape.dimensions()); - auto top1_args = - Reshape(GetTupleElement(val_args, 1), op_shape.dimensions()); + auto top1_vals = Reshape(GetTupleElement(val_args, 0), + op_shape.dimensions(), op_shape.expressions()); + auto top1_args = Reshape(GetTupleElement(val_args, 1), + op_shape.dimensions(), op_shape.expressions()); return Tuple(builder, {top1_vals, top1_args}); } diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc index 84908846dae705..7bb8694351ce92 100644 --- a/third_party/xla/xla/hlo/builder/lib/arithmetic.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc @@ -154,8 +154,9 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { int64_t dimension_size = input_shape.dimensions(axis); auto index_type = dimension_size <= INT32_MAX ? S32 : output_type; XlaOp index_init_value = Zero(builder, index_type); - auto iota_shape = - ShapeUtil::MakeShape(index_type, input_shape.dimensions()); + auto iota_shape = ShapeUtil::MakeShape(index_type, input_shape.dimensions(), + input_shape.expressions()); + XlaOp iota = Iota(builder, iota_shape, axis); XlaComputation reducer = CreateMinMaxComputation( diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.cc b/third_party/xla/xla/hlo/builder/lib/broadcast.cc index 1baec363b53a35..938b9e1c4df118 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.cc +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.cc @@ -31,8 +31,9 @@ limitations under the License. namespace xla { -absl::StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims) { +absl::StatusOr BroadcastTo( + XlaOp input, absl::Span output_dims, + absl::Span output_exprs) { XlaBuilder* builder = input.builder(); TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); absl::Span input_dims = input_shape.dimensions(); @@ -79,15 +80,37 @@ absl::StatusOr BroadcastTo(XlaOp input, } TF_RET_CHECK(input_it == input_dims.rend()); + absl::Span input_exprs = input_shape.expressions(); + std::vector broadcast_exprs; + auto input_et = input_exprs.rbegin(); + for (auto output_et = output_exprs.rbegin(); output_et != output_exprs.rend(); + ++output_et) { + if (input_et != input_exprs.rend()) { + if (*(*output_et) == *(*input_et) || + (*input_et)->is_constant() && (*input_et)->get_val() == 1) { + broadcast_exprs.push_back(*output_et); + } else if (!(*(*output_et) == *(*input_et))) { + broadcast_exprs.push_back(*input_et); + broadcast_exprs.push_back((**output_et / **input_et)->s()); + } + ++input_et; + } else { + broadcast_exprs.push_back(*output_et); + } + } + absl::c_reverse(broadcast_dims); int broadcast_shape_size = broadcast_shape.size(); for (int64_t& broadcast_dim : broadcast_dims) { broadcast_dim = broadcast_shape_size - broadcast_dim - 1; } absl::c_reverse(broadcast_shape); - XlaOp output = BroadcastInDim(input, broadcast_shape, broadcast_dims); + absl::c_reverse(broadcast_exprs); + + XlaOp output = + BroadcastInDim(input, broadcast_shape, broadcast_dims, broadcast_exprs); if (broadcast_shape != output_dims) { - output = Reshape(output, output_dims); + output = Reshape(output, output_dims, output_exprs); } return output; } diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.h b/third_party/xla/xla/hlo/builder/lib/broadcast.h index 86cf39f64ddc82..b4ccc51a40d4fd 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.h +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.h @@ -27,8 +27,9 @@ namespace xla { // Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting // rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. -absl::StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims); +absl::StatusOr BroadcastTo( + XlaOp input, absl::Span output_dims, + absl::Span output_exprs = {}); } // namespace xla diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.cc b/third_party/xla/xla/hlo/builder/lib/matrix.cc index e9fe29ea83ee6b..3bd051eb9f582f 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix.cc @@ -209,7 +209,8 @@ XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) { } return Select(GetDiagonalMask(matrix, k), - BroadcastInDim(diag, shape.dimensions(), broadcast_dims), + BroadcastInDim(diag, shape.dimensions(), broadcast_dims, + shape.expressions()), matrix); }); } @@ -341,18 +342,22 @@ xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span config) { } TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); std::vector broadcast_sizes; + std::vector broadcast_exprs; int64_t x_dim = 0; for (auto label = config.begin(); label != config.end(); ++label) { auto first_label = absl::c_find(config, *label); if (first_label == label) { broadcast_sizes.push_back(x_shape.dimensions(x_dim)); + broadcast_exprs.push_back(x_shape.expressions(x_dim)); ++x_dim; } else { broadcast_sizes.push_back( broadcast_sizes[first_label - config.begin()]); + broadcast_exprs.push_back( + broadcast_exprs[first_label - config.begin()]); } } - x = BroadcastInDim(x, broadcast_sizes, labels->at(2)); + x = BroadcastInDim(x, broadcast_sizes, labels->at(2), broadcast_exprs); return EinsumDiagonalMask(x, config); }); } @@ -568,16 +573,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, int64_t dot_dim = 0; std::vector new_dims; + std::vector new_exprs; new_dims.reserve(output_rank); + new_exprs.reserve(output_rank); TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot)); for (auto d : output_config) { if (is_output_only(d)) { new_dims.push_back(1); + new_exprs.push_back(DynExpr::one); } else { new_dims.push_back(dot_shape.dimensions(dot_dim)); + new_exprs.push_back(dot_shape.expressions(dot_dim)); } } - return Reshape(dot, new_dims); + return Reshape(dot, new_dims, new_exprs); }); } diff --git a/third_party/xla/xla/hlo/builder/lib/prng.cc b/third_party/xla/xla/hlo/builder/lib/prng.cc index 661981f19c17ea..a18fa3cfe54e9c 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng.cc @@ -41,8 +41,9 @@ namespace xla { xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, absl::Span scalars) { std::vector vectors; - absl::c_transform(scalars, std::back_inserter(vectors), - [](xla::XlaOp x) { return xla::Reshape(x, {1}); }); + absl::c_transform(scalars, std::back_inserter(vectors), [](xla::XlaOp x) { + return xla::Reshape(x, {1}, {xla::DynExpr::one}); + }); return ConcatInDim(builder, vectors, 0); } @@ -154,7 +155,8 @@ std::pair GetThreeFryInputsAndUpdatedState( XlaBuilder* builder = initial_state.builder(); auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions()); // initial_state is an R1, so reshape it to a scalar. - auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions()); + auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions(), + shape.expressions()); int64_t trailing_dims_product = 1; for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) { if (shape.dimensions(i) < 2) { @@ -245,8 +247,12 @@ XlaOp CombineShapePair(absl::Span pair, original_shape.dimensions(shape_pair.split_dim); std::vector reshape_dims(original_shape.dimensions().begin(), original_shape.dimensions().end()); + std::vector reshape_exprs(original_shape.expressions().begin(), + original_shape.expressions().end()); reshape_dims[shape_pair.split_dim] = RoundUpTo(pre_split_size, 2); - result = Reshape(result, reshape_dims); + reshape_exprs[shape_pair.split_dim] = + DynExpr::_(RoundUpTo(pre_split_size, 2)); + result = Reshape(result, reshape_dims, reshape_exprs); if (reshape_dims[shape_pair.split_dim] != pre_split_size) { result = Slice(result, std::vector(original_shape.dimensions().size(), 0), @@ -453,11 +459,11 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, } XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]}, /*dimension=*/1); - numbers = Reshape(numbers, {bits_len * 4}); + numbers = Reshape(numbers, {bits_len * 4}, {}); numbers = Slice(numbers, /*start_indices=*/{0}, /*limit_indices=*/{num_elems}, /*strides=*/{1}); - return {Reshape(numbers, shape.dimensions()), new_state}; + return {Reshape(numbers, shape.dimensions(), shape.expressions()), new_state}; } // Generates an array of primitive type U16 with the given shape containing @@ -507,7 +513,7 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, numbers = Slice(numbers, /*start_indices=*/{0}, /*limit_indices=*/{num_elems}, /*strides=*/{1}); - return {Reshape(numbers, shape.dimensions()), new_state}; + return {Reshape(numbers, shape.dimensions(), shape.expressions()), new_state}; } XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, diff --git a/third_party/xla/xla/hlo/builder/lib/slicing.cc b/third_party/xla/xla/hlo/builder/lib/slicing.cc index ae0f6b987497a1..479d743a3fbc0c 100644 --- a/third_party/xla/xla/hlo/builder/lib/slicing.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing.cc @@ -168,13 +168,16 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { std::vector index_broadcast_dims; std::vector input_broadcast_dims; std::vector sizes; + std::vector expressions; sizes.reserve(index_shape.dimensions().size()); + expressions.reserve(index_shape.expressions().size()); for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { if (i < dim) { input_broadcast_dims.push_back(i); index_broadcast_dims.push_back(i); } else if (i == dim) { sizes.push_back(input_shape.dimensions(i)); + expressions.push_back(input_shape.expressions(i)); input_broadcast_dims.push_back(i); index_broadcast_dims.push_back(i + 1); } else { @@ -182,15 +185,18 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { index_broadcast_dims.push_back(i + 1); } sizes.push_back(index_shape.dimensions(i)); + expressions.push_back(index_shape.expressions(i)); } - auto mask = Eq( - BroadcastInDim(index, sizes, index_broadcast_dims), - Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes), - dim)); + auto mask = + Eq(BroadcastInDim(index, sizes, index_broadcast_dims, expressions), + Iota(builder, + ShapeUtil::MakeShape(index_shape.element_type(), sizes, + expressions), + dim)); auto masked_input = Select( - mask, BroadcastInDim(input, sizes, input_broadcast_dims), - Zeros(builder, - ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + mask, BroadcastInDim(input, sizes, input_broadcast_dims, expressions), + Zeros(builder, ShapeUtil::MakeShape(input_shape.element_type(), sizes, + expressions))); return Reduce(masked_input, Zero(builder, input_shape.element_type()), CreateScalarIdentityWithZeroComputation( input_shape.element_type(), builder), @@ -203,7 +209,8 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { to_concat.reserve(input_shape.dimensions().size()); for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (i == dim) { - to_concat.push_back(Reshape(index, index_shape.dimensions())); + to_concat.push_back(Reshape(index, index_shape.dimensions(), + index_shape.expressions())); } else { to_concat.push_back(Iota(builder, index_shape, i)); } @@ -229,27 +236,33 @@ XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); std::vector index_broadcast_dims; std::vector sizes; + std::vector expressions; const auto rank = index_shape.dimensions().size(); sizes.reserve(rank + 1); + expressions.reserve(rank + 1); for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { if (i < dim) { index_broadcast_dims.push_back(i); } else { if (i == dim) { sizes.push_back(input_shape.dimensions(i)); + expressions.push_back(input_shape.expressions(i)); } index_broadcast_dims.push_back(i + 1); } sizes.push_back(index_shape.dimensions(i)); + expressions.push_back(index_shape.expressions(i)); } auto mask = - Eq(BroadcastInDim(index, sizes, index_broadcast_dims), + Eq(BroadcastInDim(index, sizes, index_broadcast_dims, expressions), Iota(builder, - ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim)); - auto masked_src = - Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims), - Zeros(builder, - ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + ShapeUtil::MakeShape(index_shape.element_type(), sizes, + expressions), + dim)); + auto masked_src = Select( + mask, BroadcastInDim(src, sizes, index_broadcast_dims, expressions), + Zeros(builder, ShapeUtil::MakeShape(input_shape.element_type(), sizes, + expressions))); return combiner( input, @@ -287,7 +300,8 @@ XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) { to_concat.push_back(Iota(builder, iota_shape, batch_dim)); } - to_concat.push_back(Reshape(index, index_shape.dimensions())); + to_concat.push_back( + Reshape(index, index_shape.dimensions(), index_shape.expressions())); index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim()); } for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { diff --git a/third_party/xla/xla/hlo/builder/lib/svd.cc b/third_party/xla/xla/hlo/builder/lib/svd.cc index 561c107ba8085a..9254085e69a85e 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd.cc @@ -152,9 +152,9 @@ absl::StatusOr HouseRow( auto beta = Div(ScalarLike(v_0j, 2.0), (Square(Div(sigma, v_0j, broadcast_dims)) + one)); - v = Select( - BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v, - v / v_0j); + v = Select(BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), + broadcast_dims, x_shape.expressions()), + v, v / v_0j); v = Select(Eq(idx, j), zeros + one, v); beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps), @@ -219,9 +219,9 @@ absl::StatusOr HouseCol( auto beta = Div(ScalarLike(v_0i, 2.0), (Square(Div(sigma, v_0i, broadcast_dims)) + one)); - v = Select( - BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v, - v / v_0i); + v = Select(BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), + broadcast_dims, x_shape.expressions()), + v, v / v_0i); v = Select(Eq(idx, i), zeros + one, v); beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps), @@ -582,11 +582,11 @@ absl::StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag); std::vector broadcasted_dims(num_dims - 1); std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); - auto broadcast_to_rows = - BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + auto broadcast_to_rows = BroadcastInDim( + diag, shape.dimensions(), broadcasted_dims, shape.expressions()); broadcasted_dims.back() = num_dims - 1; - auto broadcast_to_columns = - BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + auto broadcast_to_columns = BroadcastInDim( + diag, shape.dimensions(), broadcasted_dims, shape.expressions()); // Compute tolerance = w_{i,i} * w_{j,j} * epsilon^2 // Use at least F32 precision to avoid precision issues with small denormal. XlaOp tolerance; @@ -745,6 +745,7 @@ absl::StatusOr SortBySingularValuesAndPostProcessing( TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d)); const int64_t num_dims = shape.dimensions().size(); auto dimensions = shape.dimensions(); + auto expressions = shape.expressions(); const int64_t m = ShapeUtil::GetDimension(shape, -2); const int64_t n = ShapeUtil::GetDimension(shape, -1); @@ -763,7 +764,7 @@ absl::StatusOr SortBySingularValuesAndPostProcessing( d = Select(Ge(d, zeros), d, -d); result.v = Mul(result.v, sign, broadcast_dims); - d = BroadcastInDim(d, dimensions, broadcast_dims); + d = BroadcastInDim(d, dimensions, broadcast_dims, expressions); // As m >= n, only first n column vectors need to be permuted, and the rest of // m - n vectors are appended after the sorting is done. diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 96000563190d77..de4a2b10a3e49c 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -1017,22 +1017,25 @@ absl::StatusOr XlaBuilder::AddBroadcastSequence( Shape broadcast_shape = ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type()); - // Do explicit broadcast for scalar. - if (ShapeUtil::IsScalar(*operand_shape)) { - return InDimBroadcast(ShapeUtil::MakeStaticShape(broadcast_shape), operand, - {}); - } + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(*operand_shape)) { + return InDimBroadcast(ShapeUtil::MakeStaticShape(broadcast_shape), + operand, {}); + } // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; std::vector reshaped_dynamic_dimensions; + std::vector reshaped_expressions; for (int i = 0; i < operand_shape->dimensions().size(); i++) { if (operand_shape->dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape->dimensions(i)); reshaped_dynamic_dimensions.push_back( operand_shape->is_dynamic_dimension(i)); + reshaped_expressions.push_back( + operand_shape->expressions(i)); } else { TF_RET_CHECK(operand_shape->dimensions(i) == 1 && operand_shape->is_static_dimension(i)) @@ -1046,7 +1049,7 @@ absl::StatusOr XlaBuilder::AddBroadcastSequence( Shape reshaped_shape = ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions, - reshaped_dynamic_dimensions); + reshaped_dynamic_dimensions, reshaped_expressions); // Eliminate the size one dimensions. // The added reshape reduces the rank of the tensor. Hence we cannot directly @@ -1094,15 +1097,20 @@ absl::StatusOr BroadcastToTargetRank( return origin; } - // Update target_size with origin sizes using broadcast_dimensions + // Update target_size and target_exp with origin sizes and expressions using + // broadcast_dimensions absl::Span target_dimensions = target_shape.dimensions(); + absl::Span target_expressions = target_shape.expressions(); std::vector target_size{target_dimensions.begin(), target_dimensions.end()}; + std::vector target_exp{target_expressions.begin(), + target_expressions.end()}; for (int64_t origin_dim = 0; origin_dim < origin_rank; origin_dim++) { int64_t target_dim = broadcast_dimensions[origin_dim]; target_size[target_dim] = origin_shape.dimensions(origin_dim); + target_exp[target_dim] = origin_shape.expressions(origin_dim); } - return BroadcastInDim(origin, target_size, broadcast_dimensions); + return BroadcastInDim(origin, target_size, broadcast_dimensions, target_exp); } // Extract the `num_dims` counts of dimension sizes from the `op`. First, @@ -1120,7 +1128,7 @@ absl::StatusOr> ExtractDimensionSizesAndPadOnesToLeft( ? ConstantR1( /*builder=*/builder, /*values=*/{static_cast(op_shape->dimensions(i))}) - : Reshape(GetDimensionSize(op, i), {1})); + : Reshape(GetDimensionSize(op, i), {1}, {xla::DynExpr::one})); } return op_dims; } @@ -1142,7 +1150,8 @@ absl::StatusOr BroadcastScalarToOutputShapeWithUnbounded( ? ConstantR1( /*builder=*/builder, /*values=*/{static_cast(output_shape.dimensions(i))}) - : Reshape(GetDimensionSize(output, i), {1}); + : Reshape(GetDimensionSize(output, i), {1}, + {xla::DynExpr::one}); } return MhloDynamicBroadcastInDim( scalar, /*output_dimensions=*/ConcatInDim(builder, output_sizes, 0), {}, @@ -1525,12 +1534,13 @@ XlaOp XlaBuilder::Parameter( } XlaOp XlaBuilder::Broadcast(XlaOp operand, - absl::Span broadcast_sizes) { + absl::Span broadcast_sizes, + absl::Span broadcast_exprs) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN( - const Shape& shape, - ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes)); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferBroadcastShape( + *operand_shape, broadcast_sizes, broadcast_exprs)); // The client-level broadcast op just appends dimensions on the left (adds // lowest numbered dimensions). The HLO broadcast instruction is more @@ -1550,14 +1560,15 @@ XlaOp XlaBuilder::Broadcast(XlaOp operand, XlaOp XlaBuilder::BroadcastInDim( XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::Span out_dim_exp) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. - TF_ASSIGN_OR_RETURN(auto output_shape, - ShapeUtil::MakeValidatedShape( - operand_shape->element_type(), out_dim_size)); + TF_ASSIGN_OR_RETURN(auto output_shape, ShapeUtil::MakeValidatedShape( + operand_shape->element_type(), + out_dim_size, out_dim_exp)); TF_RET_CHECK(!output_shape.is_unbounded_dynamic()) << "BroadcastInDim output must shape be static or bounded dynamic " << ShapeUtil::HumanString(output_shape); @@ -1584,6 +1595,18 @@ XlaOp XlaBuilder::BroadcastInDim( .status()); std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); std::vector in_dim_dynamic(out_dim_size.size(), false); + std::vector in_expressions(out_dim_exp.begin(), + out_dim_exp.end()); + + // If out_dim_exp is empty just make expressions out of the static + // dimensions. + if (out_dim_exp.empty()) { + in_expressions.reserve(out_dim_size.size()); + std::transform(out_dim_size.begin(), out_dim_size.end(), + std::back_inserter(in_expressions), + [](int d) { return DynExpr::_(d); }); + } + for (int i = 0; i < broadcast_rank; i++) { in_dim_size[broadcast_dimensions[i]] = (operand_shape->is_unbounded_dynamic_dimension(i)) @@ -1591,9 +1614,12 @@ XlaOp XlaBuilder::BroadcastInDim( : operand_shape->dimensions(i); in_dim_dynamic[broadcast_dimensions[i]] = operand_shape->is_bounded_dynamic_dimension(i); + in_expressions[broadcast_dimensions[i]] = + operand_shape->expressions(i); } - const auto& in_dim_shape = ShapeUtil::MakeShape( - operand_shape->element_type(), in_dim_size, in_dim_dynamic); + const auto& in_dim_shape = + ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size, + in_dim_dynamic, in_expressions); TF_ASSIGN_OR_RETURN( XlaOp in_dim_broadcast, InDimBroadcast(in_dim_shape, operand, broadcast_dimensions)); @@ -1637,6 +1663,21 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, }); } +XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferSliceShape( + *operand_shape, start_indices, limit_indices, strides, + start_exprs, limit_exprs)); + return SliceInternal(shape, operand, start_indices, limit_indices, strides); + }); +} + absl::StatusOr XlaBuilder::SliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span limit_indices, @@ -1668,9 +1709,33 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, }); } +XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, DynExpr* start_expr, + DynExpr* limit_expr, int64_t stride, + int64_t dimno) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); + std::vector starts(shape->dimensions().size(), 0); + std::vector limits(shape->dimensions().begin(), + shape->dimensions().end()); + std::vector start_exprs(shape->dimensions().size(), + DynExpr::zero); + std::vector limit_exprs(shape->expressions().begin(), + shape->expressions().end()); + std::vector strides(shape->dimensions().size(), 1); + starts[dimno] = start_index; + limits[dimno] = limit_index; + start_exprs[dimno] = start_expr; + limit_exprs[dimno] = limit_expr; + strides[dimno] = stride; + return Slice(operand, starts, limits, start_exprs, limit_exprs, strides); + }); +} + XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes) { + absl::Span slice_sizes, + absl::Span slice_exprs) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; @@ -1679,9 +1744,9 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::c_transform(start_indices_shapes, std::back_inserter(start_indices_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferDynamicSliceShape( - *operand_shape, start_indices_shapes, slice_sizes)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( + *operand_shape, start_indices_shapes, + slice_sizes, slice_exprs)); return DynamicSliceInternal(shape, operand, start_indices, slice_sizes); }); } @@ -1698,6 +1763,7 @@ absl::StatusOr XlaBuilder::DynamicSliceInternal( std::vector operands = {operand}; operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); } @@ -1794,9 +1860,22 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, int64_t inferred_dimension) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape shape, - ShapeInference::InferReshapeShape( - *operand_shape, dimensions, inferred_dimension)); + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape( + *operand_shape, dimensions, + inferred_dimension, {})); + return ReshapeInternal(shape, operand, inferred_dimension); + }); +} + +XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions, + int64_t inferred_dimension) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN( + const Shape shape, + ShapeInference::InferReshapeShape(*operand_shape, dimensions, + inferred_dimension, expressions)); return ReshapeInternal(shape, operand, inferred_dimension); }); } @@ -1811,7 +1890,8 @@ XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, XlaOp XlaBuilder::DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic) { + const std::vector& dims_are_dynamic, + absl::Span expressions) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector dim_size_shape_ptrs; @@ -1820,10 +1900,11 @@ XlaOp XlaBuilder::DynamicReshape(XlaOp operand, absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(const Shape shape, - ShapeInference::InferDynamicReshapeShape( - *operand_shape, dim_size_shape_ptrs, - new_size_bounds, dims_are_dynamic)); + TF_ASSIGN_OR_RETURN( + const Shape shape, + ShapeInference::InferDynamicReshapeShape( + *operand_shape, dim_size_shape_ptrs, new_size_bounds, + dims_are_dynamic, expressions)); TF_RETURN_IF_ERROR(first_error_); std::vector operands; operands.reserve(1 + dim_sizes.size()); @@ -1865,17 +1946,21 @@ XlaOp XlaBuilder::Collapse(XlaOp operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; + std::vector new_exprs; for (int i = 0; i < original_shape->dimensions().size(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape->dimensions(i)); + new_exprs.push_back(original_shape->expressions(i)); } else { new_sizes.back() *= original_shape->dimensions(i); + new_exprs.back() = + *(new_exprs.back()) * *(original_shape->expressions(i)); } } VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; - return Reshape(operand, new_sizes); + return Reshape(operand, new_sizes, new_exprs); }); } @@ -3912,12 +3997,14 @@ XlaOp XlaBuilder::AllToAllArray( return all_to_all; } DimensionVector sizes; + std::vector expressions; const bool is_unbounded = operand_shape->is_unbounded_dynamic(); std::vector dynamic_sizes; auto GetR1DimensionSizeOrConstant = [&](XlaOp operand, int64_t dimension) -> XlaOp { if (operand_shape->is_unbounded_dynamic_dimension(dimension)) { - return Reshape(GetDimensionSize(operand, dimension), {1}); + return Reshape(GetDimensionSize(operand, dimension), {1}, + {DynExpr::one}); } return ConstantR1( this, {static_cast(operand_shape->dimensions(dimension))}); @@ -3927,15 +4014,19 @@ XlaOp XlaBuilder::AllToAllArray( for (int64_t i = 0; i < operand_shape->dimensions().size(); ++i) { if (i != split_dimension) { sizes.push_back(operand_shape->dimensions(i)); + expressions.push_back(operand_shape->expressions(i)); if (is_unbounded) { dynamic_sizes.push_back(GetR1DimensionSizeOrConstant(operand, i)); } continue; } sizes.push_back(split_count); + expressions.push_back(DynExpr::_(split_count)); sizes.push_back(operand_shape->is_unbounded_dynamic_dimension(i) ? Shape::kUnboundedSize : operand_shape->dimensions(i) / split_count); + expressions.push_back( + (*operand_shape->expressions(i) / split_count)->s()); if (is_unbounded) { dynamic_sizes.push_back(r1_split_count); @@ -3955,11 +4046,11 @@ XlaOp XlaBuilder::AllToAllArray( TF_ASSIGN_OR_RETURN( const Shape shape, ShapeUtil::MakeValidatedShape(all_to_all_shape.element_type(), sizes, - dynamic_dimensions)); + dynamic_dimensions, expressions)); all_to_all = MhloDynamicReshape(all_to_all, ConcatInDim(dynamic_sizes, 0), shape); } else { - all_to_all = Reshape(all_to_all, sizes); + all_to_all = Reshape(all_to_all, sizes, expressions); } std::vector permutation; @@ -4986,16 +5077,18 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp operand, - absl::Span broadcast_sizes) { - return operand.builder()->Broadcast(operand, broadcast_sizes); +XlaOp Broadcast(const XlaOp operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs) { + return operand.builder()->Broadcast(operand, broadcast_sizes, + broadcast_exprs); } XlaOp BroadcastInDim(const XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::Span out_dim_exp) { return operand.builder()->BroadcastInDim(operand, out_dim_size, - broadcast_dimensions); + broadcast_dimensions, out_dim_exp); } XlaOp MhloDynamicReshape(const XlaOp operand, const XlaOp output_shape, @@ -5030,21 +5123,29 @@ XlaOp Reshape(const XlaOp operand, absl::Span dimensions) { return operand.builder()->Reshape(operand, dimensions); } +XlaOp Reshape(const XlaOp operand, absl::Span dimensions, + absl::Span expressions) { + return operand.builder()->Reshape(operand, dimensions, expressions); +} + XlaOp Reshape(const Shape& shape, XlaOp operand) { return operand.builder()->Reshape(shape, operand); } XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic) { + const std::vector& dims_are_dynamic, + absl::Span expressions) { return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds, - dims_are_dynamic); + dims_are_dynamic, expressions); } XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, + absl::Span new_exprs, int64_t inferred_dimension) { - return operand.builder()->Reshape(operand, new_sizes, inferred_dimension); + return operand.builder()->Reshape(operand, new_sizes, new_exprs, + inferred_dimension); } XlaOp Collapse(const XlaOp operand, absl::Span dimensions) { @@ -5058,15 +5159,33 @@ XlaOp Slice(const XlaOp operand, absl::Span start_indices, strides); } +XlaOp Slice(const XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides) { + return operand.builder()->Slice(operand, start_indices, limit_indices, + start_exprs, limit_exprs, strides); +} + XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno) { return operand.builder()->SliceInDim(operand, start_index, limit_index, stride, dimno); } +XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, + DynExpr* start_expr, DynExpr* limit_expr, int64_t stride, + int64_t dimno) { + return operand.builder()->SliceInDim(operand, start_index, limit_index, + start_expr, limit_expr, stride, dimno); +} + XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes) { - return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes, + slice_exprs); } XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h index b5ede47dd41476..0c3c41e4cc59dc 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.h +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -519,10 +519,12 @@ class XlaBuilder { virtual XlaOp ConstantLiteral(const LiteralSlice& literal); - XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); + XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs = {}); XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::Span out_dim_exp = {}); // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim // op from the XlaBuilder. This is only intended for export to MHLO or @@ -545,12 +547,17 @@ class XlaBuilder { XlaOp Reshape(XlaOp operand, absl::Span dimensions, int64_t inferred_dimension = -1); + XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions, + int64_t inferred_dimension = -1); + XlaOp Reshape(const Shape& shape, XlaOp operand, int64_t inferred_dimension = -1); XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions = {}); XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); @@ -560,6 +567,13 @@ class XlaBuilder { XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); + + XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides); + virtual absl::StatusOr SliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, @@ -568,8 +582,13 @@ class XlaBuilder { virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); + virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, DynExpr* start_expr, + DynExpr* limit_expr, int64_t stride, int64_t dimno); + XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs = {}); virtual absl::StatusOr DynamicSliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); @@ -1244,11 +1263,13 @@ class XlaBuilder { const LiteralSlice& literal); friend XlaOp Broadcast(XlaOp operand, - absl::Span broadcast_sizes); + absl::Span broadcast_sizes, + absl::Span broadcast_expressions); friend XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::Span out_dim_exp); friend XlaOp MhloDynamicBroadcastInDim( XlaOp operand, XlaOp output_dimensions, @@ -1265,18 +1286,22 @@ class XlaBuilder { friend XlaOp Reshape(XlaOp operand, absl::Span dimensions); + friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions); + friend XlaOp Reshape(const Shape& shape, XlaOp operand); friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions); friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); - friend XlaOp ReshapeWithInferredDimension(XlaOp operand, - absl::Span new_sizes, - int64_t inferred_dimension); + friend XlaOp ReshapeWithInferredDimension( + XlaOp operand, absl::Span new_sizes, + absl::Span new_exprs, int64_t inferred_dimension); friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); @@ -1284,12 +1309,23 @@ class XlaBuilder { absl::Span limit_indices, absl::Span strides); + friend XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides); + friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); + friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, DynExpr* start_expr, + DynExpr* limit_expr, int64_t stride, int64_t dimno); + friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs); friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); @@ -1975,7 +2011,8 @@ XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); // The new dimensions index into copies of the operand, i.e. // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); +XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs = {}); // This op broadcasts the `operand` to an output with the given `shape`. // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the @@ -1993,7 +2030,8 @@ XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); // {{1 , 1}, // {2 , 2}} XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::Span out_dim_exp = {}); // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim // op from the XlaBuilder. This is only intended for export to MHLO or @@ -2038,7 +2076,8 @@ XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, // dimension dimension if dims_are_dynamic[i] is true. XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions); // This is an experimental API for creating the mhlo.dynamic_reshape op from the // XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot @@ -2050,6 +2089,9 @@ XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); // dimension sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(XlaOp operand, absl::Span dimensions); +XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions); + // Enqueues a Reshape op that uses an explicit target shape. XlaOp Reshape(const Shape& shape, XlaOp operand); @@ -2059,6 +2101,7 @@ XlaOp Reshape(const Shape& shape, XlaOp operand); // is a dynamic dimension in the output, it must be the inferred dimension. XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, + absl::Span new_exprs, int64_t inferred_dimension); // Wrapper for Reshape. @@ -2096,6 +2139,12 @@ XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); +XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides); + // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand @@ -2105,6 +2154,10 @@ XlaOp Slice(XlaOp operand, absl::Span start_indices, XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); +XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, + DynExpr* start_expr, DynExpr* limit_expr, + int64_t stride, int64_t dimno); + // Enqueues a slice operation onto the computation that slices the 'operand' // from dynamic start indices which are passed in 'start_indices'. // The size of the slice in each dimension is passed in 'slice_sizes', @@ -2116,7 +2169,8 @@ XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs = {}); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc index 07cf9a8c4bce2c..11eeefa948fa1b 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc @@ -250,6 +250,7 @@ absl::StatusOr HloPassPipeline::RunPassesInternal( } TF_RETURN_IF_ERROR(status); } + if (!pass->IsPassPipeline()) { compilation_stats_->EndPass(pass_name); } diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc index 61b261248e0e2c..aa781b082c8823 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc @@ -124,9 +124,9 @@ absl::Status CombineAllGathers(absl::Span to_combine, (*perm)[ag->all_gather_dimension()]); // Bitcast operand and update output shape. + auto sh = ShapeUtil::PermuteDimensions(*perm, operand_shape); operands.back() = - computation.AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::PermuteDimensions(*perm, operand_shape), operand)); + computation.AddInstruction(HloInstruction::CreateBitcast(sh, operand)); output_shapes.back() = ShapeUtil::PermuteDimensions(*perm, hlo->shape()); } } diff --git a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc index ddb505801d92ca..02dd7813d033de 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc @@ -80,10 +80,13 @@ absl::StatusOr BitcastDtypesExpander::ExpandInstruction( broadcasted_input_shape.push_back(input_bit_width / output_bit_width); reshaped_input_shape.push_back(1); int64_t output_bit_width_mask = (int64_t{1} << output_bit_width) - 1; - - TF_ASSIGN_OR_RETURN(input, - BroadcastTo(Reshape(input, reshaped_input_shape), - broadcasted_input_shape)); + std::vector reshaped_input_exprs( + from_shape.expressions().begin(), from_shape.expressions().end()); + reshaped_input_exprs.push_back(DynExpr::_(1)); + TF_ASSIGN_OR_RETURN( + input, BroadcastTo( + Reshape(input, reshaped_input_shape, reshaped_input_exprs), + broadcasted_input_shape)); input = BitcastConvertType(input, input_logical_type); TF_ASSIGN_OR_RETURN(Shape input_shape, b.GetShape(input)); XlaOp iota = Iota(&b, input_shape, input_shape.dimensions().size() - 1); diff --git a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc index b5df50e1956c90..7b11a211ee3e02 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc @@ -218,9 +218,9 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64_t block_size, l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } - return Select( - BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices), - FullLike(l, std::numeric_limits::quiet_NaN()), l); + return Select(BroadcastInDim(seen_error, a_shape.dimensions(), + error_dim_indices, a_shape.expressions()), + FullLike(l, std::numeric_limits::quiet_NaN()), l); }); } diff --git a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc index a3787d88ebbd93..a7d62392dec79a 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc @@ -80,26 +80,42 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims); int64_t lhs_contracting_size = 1; bool lhs_contracting_dynamic = false; + int64_t lhs_contracting_multiplier_accu = 1; + DynExpr* lhs_contracting_expression = DynExpr::one; int64_t lhs_non_contracting_size = 1; bool lhs_non_contracting_dynamic = false; + int64_t lhs_non_contracting_multiplier_accu = 1; + DynExpr* lhs_non_contracting_expression = DynExpr::one; std::vector batch_dim_sizes; batch_dim_sizes.reserve(num_batch_dims); std::vector batch_dynamic_dims; batch_dynamic_dims.reserve(num_batch_dims); + std::vector batch_expressions; + batch_expressions.reserve(num_batch_dims); + + bool lhs_contracting_is_static = true; + bool lhs_non_contracting_is_static = true; + for (int64_t i = 0; i < lhs_rank; ++i) { if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) { lhs_contracting_size *= lhs_shape.dimensions(i); lhs_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i); + lhs_contracting_expression = + (*lhs_contracting_expression) * (*lhs_shape.expressions(i)); } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), i)) { batch_dim_sizes.push_back(lhs_shape.dimensions(i)); batch_dynamic_dims.push_back(lhs_shape.is_dynamic_dimension(i)); + batch_expressions.push_back(lhs_shape.expressions(i)); } else { lhs_non_contracting_dims.push_back(i); lhs_non_contracting_size *= lhs_shape.dimensions(i); lhs_non_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i); + lhs_non_contracting_expression = + (*lhs_non_contracting_expression) * (*lhs_shape.expressions(i)); } } + // The canonical form of the lhs is // [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct] // If NonContractingDimsProduct is 1, it is omitted. @@ -123,18 +139,21 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { std::vector lhs_reshape_dims = batch_dim_sizes; std::vector lhs_reshape_dynamic_dims = batch_dynamic_dims; + std::vector lhs_reshape_expressions = batch_expressions; if (lhs_non_contracting_size > 1) { lhs_reshape_dims.push_back(lhs_non_contracting_size); lhs_reshape_dynamic_dims.push_back(lhs_non_contracting_dynamic); + lhs_reshape_expressions.push_back(lhs_non_contracting_expression->s()); } lhs_reshape_dims.push_back(lhs_contracting_size); lhs_reshape_dynamic_dims.push_back(lhs_contracting_dynamic); + lhs_reshape_expressions.push_back(lhs_contracting_expression->s()); // Reshape the contracting and non-contracting dimensions together. + auto sh_lhs = ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims, + lhs_reshape_dynamic_dims, + lhs_reshape_expressions); HloInstruction* reshaped_lhs = computation->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims, - lhs_reshape_dynamic_dims), - transposed_lhs), + HloInstruction::CreateReshape(sh_lhs, transposed_lhs), &transposed_lhs->metadata()); const auto& rhs_shape = original_dot->operand(1)->shape(); @@ -145,17 +164,29 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); int64_t rhs_non_contracting_size = 1; bool rhs_non_contracting_dynamic = false; + int64_t rhs_non_contracting_multiplier_accu = 1; + DynExpr* rhs_non_contracting_expression = DynExpr::one; int64_t rhs_contracting_size = 1; bool rhs_contracting_dynamic = false; + int64_t rhs_contracting_multiplier_accu = 1; + DynExpr* rhs_contracting_expression = DynExpr::one; + + bool rhs_contracting_is_static = true; + bool rhs_non_contracting_is_static = true; + for (int64_t i = 0; i < rhs_rank; ++i) { if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) { rhs_contracting_size *= rhs_shape.dimensions(i); rhs_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i); + rhs_contracting_expression = + (*rhs_contracting_expression) * (*rhs_shape.expressions(i)); } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), i)) { rhs_non_contracting_dims.push_back(i); rhs_non_contracting_size *= rhs_shape.dimensions(i); rhs_non_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i); + rhs_non_contracting_expression = + (*rhs_non_contracting_expression) * (*rhs_shape.expressions(i)); } } @@ -184,27 +215,35 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { rhs_reshape_dims.push_back(rhs_contracting_size); std::vector rhs_reshape_dynamic_dims = batch_dynamic_dims; rhs_reshape_dynamic_dims.push_back(rhs_contracting_dynamic); + std::vector rhs_reshape_expressions = batch_expressions; + rhs_reshape_expressions.push_back(rhs_contracting_expression->s()); if (rhs_non_contracting_size > 1) { rhs_reshape_dims.push_back(rhs_non_contracting_size); rhs_reshape_dynamic_dims.push_back(rhs_non_contracting_dynamic); + rhs_reshape_expressions.push_back(rhs_non_contracting_expression->s()); } // Reshape the contracting and non-contracting dimensions together. + auto sh_rhs = ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims, + rhs_reshape_dynamic_dims, + rhs_reshape_expressions); HloInstruction* reshaped_rhs = computation->AddInstruction( HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims, - rhs_reshape_dynamic_dims), + sh_rhs, transposed_rhs), &transposed_rhs->metadata()); std::vector dot_dims = batch_dim_sizes; std::vector dot_dynamic_dims = batch_dynamic_dims; + std::vector dot_expressions = batch_expressions; if (lhs_non_contracting_size > 1) { dot_dims.push_back(lhs_non_contracting_size); dot_dynamic_dims.push_back(lhs_non_contracting_dynamic); + dot_expressions.push_back(lhs_non_contracting_expression->s()); } if (rhs_non_contracting_size > 1) { dot_dims.push_back(rhs_non_contracting_size); dot_dynamic_dims.push_back(rhs_non_contracting_dynamic); + dot_expressions.push_back(rhs_non_contracting_expression->s()); } DotDimensionNumbers dot_dnums; @@ -251,12 +290,13 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { HloInstruction::CreateReshape(result_shape, meta), &meta->metadata()); sparse_meta.push_back(meta); } + auto sh_dot = + ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims, + dot_dynamic_dims, dot_expressions); HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims, - dot_dynamic_dims), - reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config(), - sparsity, sparse_meta)); + sh_dot, reshaped_lhs, reshaped_rhs, dot_dnums, + original_dot->precision_config(), sparsity, sparse_meta)); original_dot->SetupDerivedInstruction(dot); std::unique_ptr replacement = diff --git a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc index 33752d60cae8ce..6ddf023ffd77a7 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc @@ -158,8 +158,10 @@ void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr, Shape shape = tl.builder()->GetShape(tl).value(); std::vector broadcast_dims(shape.dimensions().size() - 1); absl::c_iota(broadcast_dims, 0); - auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims); - auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims); + auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims, + shape.expressions()); + auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims, + shape.expressions()); auto s_conj = MaybeConjugate(s, true); std::tie(tl, tr, bl, br) = @@ -179,8 +181,10 @@ void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr, std::vector broadcast_dims(shape.dimensions().size() - 1); absl::c_iota(broadcast_dims, 0); broadcast_dims.back() = shape.dimensions().size() - 1; - auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims); - auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims); + auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims, + shape.expressions()); + auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims, + shape.expressions()); auto s_conj = MaybeConjugate(s, true); std::tie(tl, tr, bl, br) = @@ -365,11 +369,12 @@ absl::Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) { TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w)); const int64_t num_dims = v_shape.dimensions().size(); auto dimensions = v_shape.dimensions(); + auto expressions = v_shape.expressions(); std::vector broadcast_dims(num_dims - 1); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims[num_dims - 2] = num_dims - 1; - w = BroadcastInDim(w, dimensions, broadcast_dims); + w = BroadcastInDim(w, dimensions, broadcast_dims, expressions); XlaOp sort_result = Sort({w, v}, diff --git a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc index dcf329f7c6d8b9..27b654587ba01a 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc @@ -58,6 +58,15 @@ std::vector ConcatVectors(absl::Span xs, return output; } +std::vector ConcatEVectors(absl::Span xs, + absl::Span ys) { + std::vector output; + output.reserve(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), std::back_inserter(output)); + std::copy(ys.begin(), ys.end(), std::back_inserter(output)); + return output; +} + // Computes sqrt(x^2 + y^2 + ...), avoiding overflow/underflow. // e.g. for 3 arguments: // def norm(x, y, z): @@ -220,11 +229,15 @@ absl::StatusOr QrExpander::QrBlock( const int64_t m = ShapeUtil::GetDimension(a_shape, -2); const int64_t n = ShapeUtil::GetDimension(a_shape, -1); + DynExpr* m_exp = ShapeUtil::GetExpression(a_shape, -2); + DynExpr* n_exp = ShapeUtil::GetExpression(a_shape, -1); const int64_t num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); + std::vector batch_exprs(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + batch_exprs[i] = ShapeUtil::GetExpression(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); @@ -248,9 +261,12 @@ absl::StatusOr QrExpander::QrBlock( minor_dim + 1); std::vector shape = batch_dims; + std::vector exprs = batch_exprs; shape.push_back(1); shape.push_back(m); - auto v_broadcast = Reshape(v, shape); + exprs.push_back(DynExpr::one); + exprs.push_back(m_exp); + auto v_broadcast = Reshape(v, shape, exprs); // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @ // (np.conj(v[np.newaxis, :]) @ a[:, j+1:])) // We use masking rather than a loop-variant shape to handle the j+1: @@ -263,7 +279,8 @@ absl::StatusOr QrExpander::QrBlock( // a[j, j] = beta // a[j+1:,j] = v[j+1:] - auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto iota = + Reshape(Iota(a.builder(), S32, m), {m, 1}, {m_exp, DynExpr::one}); auto predecessor_mask = ConvertElementType(Lt(iota, j), type); auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), std::vector(batch_dims.size(), 1)); @@ -279,7 +296,8 @@ absl::StatusOr QrExpander::QrBlock( std::vector dim_ids(num_dims); std::iota(dim_ids.begin(), dim_ids.end(), 0); new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}), - /*broadcast_dimensions=*/dim_ids); + /*broadcast_dimensions=*/dim_ids, + ConcatEVectors(batch_exprs, {m_exp, n_exp})); a = Select(Eq(iota_mn, j), new_x, a); // taus[j] = tau diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc index d41aaf49be5d32..2a6f428b65aa82 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc @@ -66,7 +66,7 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, XlaBuilder builder("rng"); XlaOp state_param = Parameter(&builder, 0, state_shape, "state"); - XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {}); + XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {}, {}); RngOutput output; switch (algorithm) { case RandomAlgorithm::RNG_THREE_FRY: @@ -83,8 +83,8 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, RandomAlgorithm_Name(algorithm)); } - XlaOp final_state = - ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); + XlaOp final_state = ConcatInDim( + &builder, {Reshape(key_op, {1}, {DynExpr::one}), output.state}, 0); Tuple(&builder, {final_state, output.value}); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); TF_ASSIGN_OR_RETURN(HloComputation * new_computation, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 8391636e00daac..90bf284e206a77 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -2673,6 +2673,8 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { SmallVector dimSizes; SmallVector newSizeBounds; std::vector dimsAreDynamic; + std::vector dimExpressions; + for (auto i = 0; i < resultType.getRank(); ++i) { auto runtimeSizeX1 = xla::Slice(outputShape, {i}, {i + 1}, {1}); dimSizes.push_back(xla::Reshape(runtimeSizeX1, {})); @@ -2683,9 +2685,10 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return op->emitOpError() << "unbounded dynamism is not supported"; newSizeBounds.push_back(hlo::isStaticDimSize(dimSize) ? dimSize : dimBound); dimsAreDynamic.push_back(!hlo::isStaticDimSize(dimSize)); + dimExpressions.push_back(xla::DynExpr::_(-40)); // Don't know. } - value_map[op] = - xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic); + value_map[op] = xla::DynamicReshape(operand, dimSizes, newSizeBounds, + dimsAreDynamic, dimExpressions); return success(); } @@ -2696,7 +2699,8 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { return failure(); value_map[op] = - xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions()); + xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions(), + xla::TypeToShape(op.getType()).expressions()); return success(); } @@ -3004,6 +3008,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { SmallVector dimSizes; SmallVector newSizeBounds; std::vector dimsAreDynamic; + std::vector dimExpressions; for (auto i = 0; i < resultType.getRank(); ++i) { auto runtimeSizeX1 = xla::Slice(outputShape, {i}, {i + 1}, {1}); dimSizes.push_back(xla::Reshape(runtimeSizeX1, {})); @@ -3014,9 +3019,11 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return op->emitOpError() << "unbounded dynamism is not supported"; newSizeBounds.push_back(hlo::isStaticDimSize(dimSize) ? dimSize : dimBound); dimsAreDynamic.push_back(!hlo::isStaticDimSize(dimSize)); + dimExpressions.push_back(xla::DynExpr::_(-50)); // Don't know } value_map[op] = - xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic); + xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic, + dimExpressions); return success(); } @@ -4558,7 +4565,8 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { return failure(); value_map[op] = - xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions()); + xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions(), + xla::TypeToShape(op.getType()).expressions()); return success(); } diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc index 7bfbbff39956ce..9369b89d77ec61 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc @@ -150,6 +150,7 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); std::vector is_dynamic(rank, false); + std::vector expressions(rank, DynExpr::_(-60)); for (int64_t dim = 0; dim < rank; ++dim) { int64_t size = t.getDimSize(dim); if (size == ShapedType::kDynamic) { @@ -191,7 +192,8 @@ Shape TypeToShape(mlir::Type type) { return sparse_shape; } - return ShapeUtil::MakeShape(primitive_type, shape, is_dynamic); + return ShapeUtil::MakeShape(primitive_type, shape, is_dynamic, + expressions); } else if (auto tuple_type = mlir::dyn_cast(type)) { llvm::SmallVector shapes; shapes.reserve(tuple_type.size()); diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index c3f432897567de..5ec762ae5cd56c 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -167,6 +167,9 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( bool allow_runtime_calls) { PrimitiveType element_type = dot_info.result_shape.element_type(); + // Force Eigen all the time. + return DotImplementationStrategy::kEigen; + // Batched dot either handled by a runtime call or expanded into a sequence // of non-batch dot operations. DCHECK(dot_info.dim_nums.lhs_batch_dimensions_size() == 0 && @@ -253,6 +256,12 @@ class DotOpEmitter { // The number of columns on the RHS. int64_t n; + DynExpr* m_expr; + + DynExpr* k_expr; + + DynExpr* n_expr; + // True if the LHS matrix is column major. bool lhs_column_major; @@ -858,21 +867,20 @@ absl::Status DotOpEmitter::EmitCallToRuntime() { if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); + std::swap(mat_mult_dims.m_expr, mat_mult_dims.n_expr); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } - // We work under the assumption that only M can be dynamic. - int batch_multiplier = target_array_.GetShape().outer_multiplier(); - llvm::Value* m_val = (batch_multiplier > 0) - ? xla::llvm_ir::GetBatchDimByName(b_, batch_multiplier) - : b_->getInt64(mat_mult_dims.m); - - b_->CreateCall(matmul_func, - {executable_run_options_value_, target_array_.GetBasePointer(), - lhs->GetBasePointer(), rhs->GetBasePointer(), m_val, - b_->getInt64(mat_mult_dims.n), b_->getInt64(mat_mult_dims.k), - b_->getInt32(transpose_lhs), b_->getInt32(transpose_rhs)}); + llvm::Value* m_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); + llvm::Value* n_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); + llvm::Value* k_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); + + b_->CreateCall( + matmul_func, + {executable_run_options_value_, target_array_.GetBasePointer(), + lhs->GetBasePointer(), rhs->GetBasePointer(), m_val, n_val, k_val, + b_->getInt32(transpose_lhs), b_->getInt32(transpose_rhs)}); return absl::OkStatus(); } @@ -988,6 +996,14 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { /*n=*/rhs_shape.dimensions().size() <= 1 ? 1LL : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)), + /*m_expr=*/lhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : lhs_shape.expressions(1LL - dim_nums.lhs_contracting_dimensions(0)), + /*k_expr=*/ + lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)), + /*n_expr=*/rhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : rhs_shape.expressions(1LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || dim_nums.lhs_contracting_dimensions(0) == 1, @@ -1019,6 +1035,14 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const { /*n=*/rhs_shape.dimensions().size() <= 1 ? 1LL : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)), + /*m_expr=*/lhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : lhs_shape.expressions(2LL - dim_nums.lhs_contracting_dimensions(0)), + /*k_expr=*/ + lhs_shape.expressions(1LL + dim_nums.lhs_contracting_dimensions(0)), + /*n_expr=*/rhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : rhs_shape.expressions(2LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || dim_nums.lhs_contracting_dimensions(0) == 1, @@ -1098,22 +1122,34 @@ absl::Status EmitNonBatchDotOperation( Shape DropFirstDim(const Shape& shape) { absl::Span array_shape_dims(shape.dimensions()); + absl::Span array_shape_exprs(shape.expressions()); array_shape_dims.remove_prefix(1); - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - array_shape_dims); + array_shape_exprs.remove_prefix(1); + return ShapeUtil::MakeShapeWithDescendingLayout( + shape.element_type(), array_shape_dims, array_shape_exprs); } Shape CollapseFirstNDims(const Shape& shape, int64_t n) { absl::Span input_shape_dims(shape.dimensions()); + absl::Span input_expressions(shape.expressions()); int64_t prefix_dim = std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n, 1ll, std::multiplies()); + + DynExpr* prefix_expression = std::accumulate( + input_expressions.begin(), input_expressions.begin() + n, DynExpr::one, + [](DynExpr* acc, DynExpr* v) { return (*acc) * (*v); }); + DimensionVector result_dims; + std::vector result_expressions; result_dims.push_back(prefix_dim); + result_expressions.push_back(prefix_expression->s()); std::copy(input_shape_dims.begin() + n, input_shape_dims.end(), std::back_inserter(result_dims)); - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - result_dims); + std::copy(input_expressions.begin() + n, input_expressions.end(), + std::back_inserter(result_expressions)); + return ShapeUtil::MakeShapeWithDescendingLayout( + shape.element_type(), result_dims, result_expressions); } llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilderBase* b, diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index cbcea0cb2d1f14..2925ddf8127838 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2067,7 +2067,7 @@ absl::Status IrEmitter::HandleSlice(HloInstruction* slice) { const int64_t memcpy_elements = primitive_elements_per_logical_element * memcpy_logical_elements; - EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements, + EmitTransferElements(memcpy_dest, memcpy_source, DynExpr::_(memcpy_elements), slice->shape().element_type(), target_array, source_array); @@ -2362,21 +2362,16 @@ absl::Status IrEmitter::HandleOuterBatchValue(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); llvm_ir::IrArray out_array = GetIrArrayFor(hlo); - int64_t multiplier = hlo->operand(0)->shape().outer_multiplier(); - if (multiplier <= 0) { - LOG(ERROR) << "Invalid outer multiplier for GetOuterBatchValue: " - << multiplier; - return absl::InvalidArgumentError( - "Invalid outer multiplier for GetOuterBatchValue"); - } - llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(b(), multiplier); - - llvm_ir::IrArray::Index out_index(/*multidimensional_index=*/{}, hlo->shape(), - b()->getInt32Ty()); + llvm::Value* expr_value = + llvm_ir::EmitExpression(b(), hlo->operand(0)->shape().expressions(0)); - llvm::Value* out_ptr = out_array.EmitArrayElementAddress(out_index, b()); - b()->CreateStore(bdim_value, out_ptr); + auto it = emitted_value_.find(hlo); + if (it == emitted_value_.end()) { + LOG(ERROR) << "No buffer assigned for instruction " << hlo->name(); + } + llvm::Value* dest_ptr = it->second; + b()->CreateStore(expr_value, dest_ptr); return absl::OkStatus(); } @@ -3152,11 +3147,11 @@ absl::Status EmitFastConcatenate( // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = target_array.EmitArrayElementAddress(target_index, &b, "target_region"); - int64_t byte_offset_into_target_region = 0; + llvm::Value* byte_offset_into_target_region = b.getInt64(0); - int64_t inner_dims_product = absl::c_accumulate( - inner_dims, int64_t{1}, [&](int64_t product, int64_t inner_dim) { - return product * output_shape.dimensions(inner_dim); + DynExpr* inner_exprs_product = absl::c_accumulate( + inner_dims, DynExpr::one, [&](DynExpr* product, int64_t inner_dim) { + return *product * *output_shape.expressions(inner_dim); }); // For each operand, emit a memcpy from the operand to the target of size @@ -3169,18 +3164,24 @@ absl::Status EmitFastConcatenate( llvm::Value* copy_source_address = source_array.EmitArrayElementAddress(source_index, &b, "src_addr"); - llvm::Value* copy_target_address = - b.CreateGEP(b.getInt8Ty(), target_region_begin, - b.getInt64(byte_offset_into_target_region)); + llvm::Value* copy_target_address = b.CreateGEP( + b.getInt8Ty(), target_region_begin, byte_offset_into_target_region); + + auto cexpr = input_shape.expressions(concat_dim); - ::xla::cpu::EmitTransferElements( - copy_target_address, copy_source_address, - inner_dims_product * input_shape.dimensions(concat_dim), primitive_type, - target_array, source_array, module, b); + ::xla::cpu::EmitTransferElements(copy_target_address, copy_source_address, + (*inner_exprs_product * *cexpr)->s(), + primitive_type, target_array, source_array, + module, b); - byte_offset_into_target_region += inner_dims_product * - input_shape.dimensions(concat_dim) * - primitive_type_size; + llvm::Value* concat_dim_count = xla::llvm_ir::EmitExpression( + &b, (*inner_exprs_product * *input_shape.expressions(concat_dim))->s()); + + llvm::Value* concat_dim_size = + b.CreateMul(concat_dim_count, b.getInt64(primitive_type_size)); + byte_offset_into_target_region = + b.CreateAdd(byte_offset_into_target_region, concat_dim_size, + "byte_offset_into_target_region"); } if (!outer_dims.empty()) { @@ -3391,7 +3392,7 @@ llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call, } void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, + xla::DynExpr* element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array) { @@ -3401,7 +3402,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } void EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, PrimitiveType primitive_type, + xla::DynExpr* element_count, + PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array, llvm::Module* module, llvm::IRBuilderBase& b) { @@ -3413,7 +3415,7 @@ void EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm::Type* primitive_llvm_type = llvm_ir::PrimitiveTypeToIrType(primitive_type, module->getContext()); - if (element_count == 1) { + if (element_count == DynExpr::one) { auto* load_instruction = b.CreateAlignedLoad(primitive_llvm_type, source, element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); @@ -3421,11 +3423,12 @@ void EmitTransferElements(llvm::Value* target, llvm::Value* source, b.CreateAlignedStore(load_instruction, target, element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { + auto element_count_value = xla::llvm_ir::EmitExpression(&b, element_count); + llvm::Value* elements_size = + b.CreateMul(element_count_value, b.getInt64(primitive_type_size)); auto* memcpy_instruction = b.CreateMemCpy( target, /*DstAlign=*/llvm::Align(element_alignment), source, - /*SrcAlign=*/llvm::Align(element_alignment), - element_count * primitive_type_size); - + /*SrcAlign=*/llvm::Align(element_alignment), elements_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. std::map merged_metadata = @@ -3938,8 +3941,10 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( if (!target_shape.IsOpaque()) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, - target_shape); + if (!target_shape.has_dynamic_expr()) { + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } } return param_address_untyped; } @@ -3985,7 +3990,10 @@ llvm::Value* IrEmitter::EmitGlobalBufferPointer( AttachInvariantLoadMetadataForLoad(tempbuf_address_base); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); - AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); + + if (!target_shape.has_dynamic_expr()) + AttachDereferenceableMetadataForLoad(tempbuf_address_base, + allocation.size()); llvm::Value* tempbuf_address_untyped = tempbuf_address_base; // Any explicit buffer pointer should point to the start of the slice. @@ -4086,9 +4094,38 @@ absl::Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* source_value = GetEmittedValueFor(&source); llvm::Value* destination_value = GetEmittedValueFor(&destination); int64_t source_size = ByteSizeOf(source.shape()); - // TODO(b/63762267): Be more aggressive about specifying alignment. - MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value, - /*SrcAlign=*/llvm::Align(1), source_size); + auto shape = source.shape(); + auto expressions = shape.expressions(); + bool is_dynamic = + std::any_of(expressions.begin(), expressions.end(), + [](DynExpr* e) { return e->is_dynamic(); }); + if (is_dynamic) { + llvm::LLVMContext& ctx = b()->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + int64_t dimensions_accu = 1; + DynExpr* expression_accu = DynExpr::one; + for (int i = 0; i < shape.dimensions_size(); i++) { + auto expression = shape.expressions(i); + if (expression->is_dynamic()) { + dimensions_accu *= shape.dimensions(i); + expression_accu = (*expression_accu) * (*expression); + } + } + llvm::Value* expr_value = + xla::llvm_ir::EmitExpression(b(), expression_accu->s()); + // Divide the size in bytes by the size of the dynamic dimension(s). + // TODO: make that less hacky + llvm::ConstantInt* size = + llvm::ConstantInt::get(i64Type, source_size / dimensions_accu, true); + llvm::Value* memcopy_size = + b()->CreateMul(expr_value, size, "memcopy_size"); + MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value, + /*SrcAlign=*/llvm::Align(1), memcopy_size); + } else { + // TODO(b/63762267): Be more aggressive about specifying alignment. + MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value, + /*SrcAlign=*/llvm::Align(1), source_size); + } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 0bd7fcc6b377a9..4d4af3805e7b1d 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -570,7 +570,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". void EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, PrimitiveType primitive_type, + xla::DynExpr* element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); @@ -860,7 +860,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Decoupled implementation of IrEmitter::EmitTransferElements. void EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, PrimitiveType primitive_type, + xla::DynExpr* element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array, llvm::Module* module, llvm::IRBuilderBase& b); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index 927d43f9d42fc5..8ce0ad8129a155 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -203,16 +203,10 @@ IrEmitter2::EmitGetOuterBatchValueHostKernel(const HloInstruction* getBatch) { EmitKernelPrototype(getBatch)); llvm_ir::IrArray operand_array = kernel_prototype.arguments[0]; llvm_ir::IrArray output_array = kernel_prototype.results[0]; - int64_t multiplier = getBatch->operand(0)->shape().outer_multiplier(); - if (multiplier <= 0) { - LOG(ERROR) << "Invalid outer multiplier for GetOuterBatchValue: " - << multiplier; - return absl::InvalidArgumentError( - "Invalid outer multiplier for GetOuterBatchValue"); - } + xla::DynExpr* expr = getBatch->operand(0)->shape().expressions(0); llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(&b, multiplier); + llvm::Value* bdim_value = llvm_ir::EmitExpression(&b, expr); llvm_ir::IrArray::Index output_index(/*multidimensional_index=*/{}, getBatch->shape(), b.getInt32Ty()); llvm::Value* output_ptr = diff --git a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc index adabc35f90c23f..e3af0399381b79 100644 --- a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc @@ -61,7 +61,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // performance with a large improvement in compile time. auto unroll_mode = (i == 0) ? llvm_ir::UnrollMode::kDefaultUnroll : llvm_ir::UnrollMode::kNoUnroll; - bool is_batch_dim = (dimension == 0) && shape_.outer_multiplier() > 0; if (bounds_index < dynamic_loop_bounds_->size()) { // Emit dynamic loop bounds for this dimension. Dynamic loop bounds @@ -72,7 +71,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::unique_ptr loop = loop_nest.AddLoop( /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, end_index, unroll_mode, /*prevent_vectorization*/ false, - is_batch_dim); + /* expression */ shape_.expressions(dimension)); array_multi_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. @@ -80,7 +79,8 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), /*suffix=*/absl::StrFormat("dim.%d", dimension), unroll_mode, - /*prevent_vectorization*/ false, is_batch_dim); + /*prevent_vectorization*/ false, + /* expression */ shape_.expressions(dimension)); array_multi_index[dimension] = loop->GetIndVarValue(); } } diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 2963fd4a03d5b1..649e89e11837c0 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -3289,8 +3289,9 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( cases.emplace_back(current_offset, operand); llvm::Value* cdim = source_index.GetConstantWithIndexType( operand->shape().dimensions(concat_dim)); - if (concat_dim == 0 && operand->shape().outer_multiplier() > 0) { - cdim = llvm_ir::GetBatchDimByName(b_, operand->shape().outer_multiplier()); + if (operand->shape().expressions(concat_dim)->is_dynamic()) { + cdim = llvm_ir::EmitExpression( + b_, operand->shape().expressions(concat_dim)); } current_offset = b_->CreateAdd(current_offset, cdim, "current_offset"); coffset += operand->shape().dimensions(concat_dim); @@ -3625,14 +3626,9 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( int64_t shape_dim = hlo->operand(0)->shape().dimensions(i); llvm::Value* bound = index_typed_const(shape_dim); - int64_t multiplier = - (i == 0) ? hlo->operand(0)->shape().outer_multiplier() : -1; - if (multiplier > 0) { - bound = llvm_ir::GetBatchDimByName(b_, multiplier); - } else if (shape_dim == 977) { - // This should be deleted. - LOG(ERROR) << "Dynamic batch marker: No multiplier for batch dim: " - << hlo->ToString(); + if (hlo->operand(0)->shape().expressions(i)->is_dynamic()) { + bound = llvm_ir::EmitExpression( + b_, hlo->operand(0)->shape().expressions(i)); } in_bounds = And(in_bounds, ICmpSLT(multi_index[i], bound), "in_bounds"); @@ -3703,12 +3699,14 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDot( }; llvm::Value* contracted_bound = index_typed_const(contracted_dim_size); - int64_t multiplier = (lhs_contracting_dim == 0) - ? hlo->operand(0)->shape().outer_multiplier() - : -1; - if (multiplier > 0) { - llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(b_, multiplier); - contracted_bound = bdim_value; + + if (!hlo->operand(0) + ->shape() + .expressions(lhs_contracting_dim) + ->is_constant()) { + llvm::Value* expr_value = llvm_ir::EmitExpression( + b_, hlo->operand(0)->shape().expressions(lhs_contracting_dim)); + contracted_bound = expr_value; } std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( @@ -3907,9 +3905,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); std::vector source_multi_index = target_index.multidim(); for (int64_t dim : hlo->dimensions()) { - source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType( - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + if (hlo->shape().expressions(dim)->is_dynamic()) { + llvm::Value* one = target_index.GetConstantWithIndexType(1); + llvm::Value* expr_value = + llvm_ir::EmitExpression(b_, hlo->shape().expressions(dim)); + source_multi_index[dim] = + Sub(Sub(expr_value, one), target_index[dim]); + } else { + source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType( + hlo->shape().dimensions(dim) - 1), + target_index[dim]); + } } llvm_ir::IrArray::Index source_index( source_multi_index, operand->shape(), target_index.GetType()); @@ -3991,8 +3997,6 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> absl::StatusOr { - // [Steven] the problem is that slices are represented by integer ranges. - // If these are based on the magic number they are wrong. IrArray::Index sliced_index = index.SourceIndexOfSlice( /*operand_shape=*/hlo->operand(0)->shape(), /*starts=*/hlo->slice_starts(), @@ -4266,12 +4270,10 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( int64_t dim_bound = reduce_window->inputs()[0]->shape().dimensions(i); llvm::Value* shape_bound = index_typed_const(dim_bound); - int64_t multiplier = - (i == 0) ? reduce_window->inputs()[0]->shape().outer_multiplier() : -1; - - if (multiplier > 0) { - llvm::Value* bdim_value = llvm_ir::GetBatchDimByName(b_, multiplier); - shape_bound = bdim_value; + if (reduce_window->inputs()[0]->shape().expressions(i)->is_dynamic()) { + llvm::Value* expr_value = llvm_ir::EmitExpression( + b_, reduce_window->inputs()[0]->shape().expressions(i)); + shape_bound = expr_value; } in_bounds = diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index b815b1ffc79627..3a33c5f89a2e36 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -655,9 +655,12 @@ absl::StatusOr CollapseFirstNDims(HloInstruction* operand, CHECK_GE(operand_shape.dimensions_size(), n); int64_t new_shape_leading_bound = 1; bool new_shape_leading_is_dynamic = false; + DynExpr* new_shape_leading_expression = DynExpr::one; for (int64_t i = 0; i < n; i++) { new_shape_leading_bound *= operand_shape.dimensions(i); new_shape_leading_is_dynamic |= operand_shape.is_dynamic_dimension(i); + new_shape_leading_expression = + (*new_shape_leading_expression) * (*operand_shape.expressions(i)); } std::vector new_shape_dims; @@ -675,8 +678,16 @@ absl::StatusOr CollapseFirstNDims(HloInstruction* operand, operand_shape.dynamic_dimensions().end(), std::back_inserter(new_shape_dynamic_dims)); - Shape output_shape = ShapeUtil::MakeShape( - operand_shape.element_type(), new_shape_dims, new_shape_dynamic_dims); + std::vector new_shape_expressions; + new_shape_expressions.reserve(operand_shape.dimensions_size() - n + 1); + new_shape_expressions.push_back(new_shape_leading_expression->s()); + auto exprs = operand_shape.expressions(); + std::copy(exprs.begin() + n, exprs.end(), + std::back_inserter(new_shape_expressions)); + + Shape output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims, + new_shape_dynamic_dims, new_shape_expressions); return MakeReshapeHlo(output_shape, operand); } diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 389eee09ef9bda..e919ebeb8f0f59 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -281,8 +281,11 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. for (int64_t i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { + bool is_dynamic = input_shape.expressions(i)->is_dynamic(); llvm::Value* divisor = - GetConstantWithIndexType(input_shape.dimensions(i)); + is_dynamic ? llvm_ir::EmitExpression(builder, + input_shape.expressions(i)) + : GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { source_multidim_index[i] = GetConstantWithIndexType(0); } else if (i == common_factors[k].first) { @@ -560,42 +563,24 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, gep_indices.push_back(actual_index[dimension]); } -#define DYN_DIMS -#ifdef DYN_DIMS - - llvm::ArrayType* outerArray = llvm::dyn_cast(pointee_type_); - - CHECK(outerArray) << "Expected outer array type."; - - llvm::Value* gep; - - if (shape_.outer_multiplier() > 0) { - - // Extract the inner array type: [N x T] - llvm::Type* innerArray = outerArray->getElementType(); - - CHECK(innerArray) << "Expected inner array type."; - - // Create a new array type: [0 x [N x T]] - llvm::ArrayType* zeroOuterArray = llvm::ArrayType::get(innerArray, 0); - - llvm::PointerType* newPtrTy = llvm::PointerType::getUnqual(zeroOuterArray); - llvm::Value* castedPtr = b->CreateBitCast(base_ptr_, newPtrTy); - - gep = - b->CreateInBoundsGEP(zeroOuterArray, - castedPtr, - gep_indices, llvm_ir::AsStringRef(name)); + // Do not make a dynamic "GEP" if only the first dimension is dynamic since + // it's always indiced with 0 (i.e. the dynamic dimension has no impact on the + // address computation). + auto expressions = shape_.expressions(); + bool dynamic_first_dim = + expressions[0]->is_dynamic() && + std::all_of(expressions.begin() + 1, expressions.end(), + [](DynExpr* e) { return e->is_constant(); }); + if (!dynamic_first_dim && shape_.has_dynamic_expr()) { + llvm::Type* element_type = + PrimitiveTypeToIrType(shape_.element_type(), b->getContext()); + return llvm_ir::createDynamicGEP( + b, base_ptr_, gep_indices, shape_.dimensions(), expressions, + element_type, llvm_ir::AsStringRef(name)); } else { - gep = b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, - llvm_ir::AsStringRef(name)); + return b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); } -#else - auto gep = b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, - llvm_ir::AsStringRef(name)); -#endif - - return gep; } llvm::Value* IrArray::EmitLinearArrayElementAddress( diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc index f3159e0aacd04c..562260717ba35d 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc @@ -186,30 +186,29 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } -std::unique_ptr ForLoopNest::AddLoop(absl::string_view suffix, - llvm::Value* start_index, - llvm::Value* end_index, - UnrollMode unroll_mode, - bool prevent_vectorization, - bool is_batch_dim) { +std::unique_ptr ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + UnrollMode unroll_mode, bool prevent_vectorization, + DynExpr* expression) { return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), - unroll_mode, prevent_vectorization, is_batch_dim); + unroll_mode, prevent_vectorization, expression); } std::unique_ptr ForLoopNest::AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization, - bool is_batch_dim) { + DynExpr* expression) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } llvm::Value* actual_end = end_index; - if (is_batch_dim) { + if (expression && expression->is_dynamic()) { // Get batch dim and compare with end_index to use minimum value - llvm::Value* batch_dim = GetBatchDimByName(b_); - actual_end = b_->CreateSelect(b_->CreateICmpULT(end_index, batch_dim), - end_index, batch_dim, "loop_end_min"); + llvm::Value* expr_value = + llvm_ir::EmitExpression(b_, expression); + actual_end = b_->CreateSelect(b_->CreateICmpULT(end_index, expr_value), + end_index, expr_value, "loop_end_min"); } std::unique_ptr loop(new ForLoop( /*prefix=*/name_, suffix, start_index, actual_end, stride, unroll_mode, @@ -229,18 +228,17 @@ std::unique_ptr ForLoopNest::AddLoop( return loop; } -std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, - int64_t end_index, - absl::string_view suffix, - UnrollMode unroll_mode, - bool prevent_vectorization, - bool is_batch_dim) { +std::unique_ptr ForLoopNest::AddLoop( + int64_t start_index, int64_t end_index, absl::string_view suffix, + UnrollMode unroll_mode, bool prevent_vectorization, + DynExpr* expression) { CHECK_LE(start_index, end_index); - llvm::Value* end = is_batch_dim ? GetBatchDimByName(b_) : GetConstantWithIndexType(end_index); - is_batch_dim = false; + llvm::Value* end = (expression && expression->is_dynamic()) + ? EmitExpression(b_, expression) + : GetConstantWithIndexType(end_index); return AddLoop(suffix, GetConstantWithIndexType(start_index), end, - unroll_mode, prevent_vectorization, is_batch_dim); + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, @@ -248,15 +246,15 @@ std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization, - bool is_batch_dim) { + DynExpr* expression) { CHECK_LE(start_index, end_index); - llvm::Value* end = is_batch_dim ? GetBatchDimByName(b_) : GetConstantWithIndexType(end_index); - is_batch_dim = false; - + llvm::Value* end = (expression && expression->is_dynamic()) + ? EmitExpression(b_, expression) + : GetConstantWithIndexType(end_index); return AddLoop(suffix, GetConstantWithIndexType(start_index), end, GetConstantWithIndexType(stride), unroll_mode, - prevent_vectorization, is_batch_dim); + prevent_vectorization); } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, @@ -272,14 +270,13 @@ std::vector ForLoopNest::AddLoopsForShapeOnDimensions( absl::string_view suffix) { std::vector multi_index(shape.dimensions().size()); for (int64_t dimension : dimensions) { - bool is_batch_dim = (dimension == 0) && shape.outer_multiplier() > 0; std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ llvm_ir::IrName(suffix, absl::StrCat(dimension)), /*unroll_mode=*/llvm_ir::UnrollMode::kDefaultUnroll, - /*prevent_vectorization=*/false, is_batch_dim); + /*prevent_vectorization=*/false, shape.expressions(dimension)); multi_index[dimension] = loop->GetIndVarValue(); } return multi_index; diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.h b/third_party/xla/xla/service/llvm_ir/llvm_loop.h index c6cbcafdda6dbb..5414f019121d7d 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.h @@ -201,14 +201,14 @@ class ForLoopNest { absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, bool is_batch_dim = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, bool is_batch_dim = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. @@ -216,13 +216,13 @@ class ForLoopNest { int64_t start_index, int64_t end_index, int64_t stride, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, bool is_batch_dim = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( int64_t start_index, int64_t end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false, bool is_batch_dim = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index b337b0dcc5bb32..7790714743fe79 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -292,9 +292,13 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::LLVMContext& context) { result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes().size()); } else if (shape.IsArray()) { - for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) { - result_type = - llvm::ArrayType::get(result_type, shape.dimensions(dimension)); + auto dimensions = LayoutUtil::MinorToMajor(shape); + for (int i = 0; i < dimensions.size(); i++) { + // The MinorToMajor order reverses dimensions... + bool is_dynamic = + shape.expressions(dimensions.size() - 1 - i)->is_dynamic(); + int64_t dim_val = is_dynamic ? 0 : shape.dimensions(dimensions[i]); + result_type = llvm::ArrayType::get(result_type, dim_val); } } return result_type; @@ -815,7 +819,8 @@ void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilderBase* b, b->SetInsertPoint(continued, continued->getFirstInsertionPt()); } -llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier) { +llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier, + int64_t offset) { llvm::Function* function = b->GetInsertBlock()->getParent(); llvm::LLVMContext& ctx = b->getContext(); llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); @@ -870,8 +875,80 @@ llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier) { llvm::ConstantInt* m = llvm::ConstantInt::get(i64Type, multiplier, true); bdim_scaled = b->CreateMul(loadedValue, m, "bdim_scaled"); } + if (offset != 0){ + llvm::ConstantInt* offset_value = + llvm::ConstantInt::get(i64Type, offset, true); + bdim_scaled = b->CreateAdd(bdim_scaled, offset_value, "bdim_offset"); + } return bdim_scaled; } +llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { + llvm::Function* function = b->GetInsertBlock()->getParent(); + llvm::LLVMContext& ctx = b->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + if (expr == nullptr) return nullptr; + if (expr->is_constant()) + return llvm::ConstantInt::get(i64Type, expr->get_val(), true); + if (Variable* var_node = dynamic_cast(expr)) { + // For now we can just use %bdim... + return GetBatchDimByName(b); + } + if (Mul* mul_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, mul_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, mul_node->get_rhs()); + return b->CreateMul(v_lhs, v_rhs, "mul_dims"); + } + // TODO: Check if this should ever happen + if (Div* div_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, div_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, div_node->get_rhs()); + return b->CreateUDiv(v_lhs, v_rhs, "div_dims"); + } + if (Add* add_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, add_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, add_node->get_rhs()); + return b->CreateAdd(v_lhs, v_rhs, "add_dims"); + } + if (Sub* sub_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, sub_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, sub_node->get_rhs()); + return b->CreateSub(v_lhs, v_rhs, "sub_dims"); + } + return nullptr; +} + +llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, + llvm::Value* base_ptr, + const std::vector& indices, + absl::Span dims, + absl::Span expressions, + llvm::Type* elem_type, + const llvm::Twine& name) { + llvm::Value* total_index = builder->getInt64(0); + llvm::Type* int64_ty = builder->getInt64Ty(); + + for (size_t i = 0; i < indices.size(); ++i) { + // The stride is the product of all dimensions to the right of this index. + llvm::Value* stride = builder->getInt64(1); + for (size_t j = i; j < dims.size(); ++j) { + if (expressions[j]->is_dynamic()) { + llvm::Value* expr_value = + EmitExpression(builder, expressions[j]); + stride = builder->CreateMul(stride, expr_value, "stride.dyn"); + } else { + stride = builder->CreateMul( + stride, llvm::ConstantInt::get(int64_ty, dims[j]), "stride.static"); + } + } + llvm::Value* scaled_index = + builder->CreateMul(indices[i], stride, "idx.scaled"); + total_index = builder->CreateAdd(total_index, scaled_index, "idx.total"); + } + + // Final GEP: result = base + total_index * sizeof(elem_type) + return builder->CreateGEP(elem_type, base_ptr, total_index, name); +} + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index 732e4400e3d467..c868c81574882f 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -334,7 +334,18 @@ llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilderBase* b); void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilderBase* b, llvm::BasicBlock* return_block = nullptr); -llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier = 1); +llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier = 1, + int64_t offset = 0); + +llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, + llvm::Value* base_ptr, + const std::vector& indices, + absl::Span dims, + absl::Span expressions, + llvm::Type* elem_type, + const llvm::Twine& name = ""); + +llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr); } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index b0879c8ed4b3de..aafdd02d753fc2 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -199,9 +199,9 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( bool dynamic = false; for (int i = 0; i < shape_.dimensions_size(); i++) { - int64_t multiplier = (i == 0) ? shape_.outer_multiplier() : -1; - if (multiplier > 0) { - dynamic_dims[i] = xla::llvm_ir::GetBatchDimByName(b_, multiplier); + auto expr = shape_.expressions(i); + if (expr->is_dynamic()) { + dynamic_dims[i] = xla::llvm_ir::EmitExpression(b_, expr); shape_.set_dynamic_dimension(i, true); dynamic = true; } diff --git a/third_party/xla/xla/service/llvm_ir/tuple_ops.cc b/third_party/xla/xla/service/llvm_ir/tuple_ops.cc index b4b3fa30affbcb..14bf83174ca269 100644 --- a/third_party/xla/xla/service/llvm_ir/tuple_ops.cc +++ b/third_party/xla/xla/service/llvm_ir/tuple_ops.cc @@ -104,7 +104,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64_t index, llvm::LoadInst* src_buffer = b->CreateLoad(element_pointee_type, element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. - if (!target_shape.IsOpaque()) { + if (!target_shape.IsOpaque() && !target_shape.has_dynamic_expr()) { SetDereferenceableMetadataForLoad( src_buffer, ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); diff --git a/third_party/xla/xla/service/reduce_scatter_combiner.cc b/third_party/xla/xla/service/reduce_scatter_combiner.cc index a86271036f6baa..9b889efb57e3a6 100644 --- a/third_party/xla/xla/service/reduce_scatter_combiner.cc +++ b/third_party/xla/xla/service/reduce_scatter_combiner.cc @@ -131,9 +131,9 @@ absl::Status CombineReduceScatters( std::swap((*perm)[most_frequent_dim], (*perm)[rs->scatter_dimension()]); // Bitcast operand and update output shape. + auto sh = ShapeUtil::PermuteDimensions(*perm, operand_shape), operand; operands.back() = - computation.AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::PermuteDimensions(*perm, operand_shape), operand)); + computation.AddInstruction(HloInstruction::CreateBitcast(sh)); output_shapes.back() = ShapeUtil::PermuteDimensions(*perm, hlo->shape()); } } diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 7985c930d812b5..f8264e3251250f 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -195,6 +195,7 @@ absl::StatusOr InferWindowOutputShape(const Shape& base_shape, std::vector output_dimensions(window.dimensions_size()); std::vector output_is_dynamic(window.dimensions_size()); + std::vector output_expressions(window.dimensions_size()); for (int64_t i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -230,10 +231,12 @@ absl::StatusOr InferWindowOutputShape(const Shape& base_shape, padded_dilated_base, dilated_window, dim.stride()); } output_is_dynamic[i] = base_shape.is_dynamic_dimension(i); + output_expressions[i] = base_shape.expressions(i); } return ShapeUtil::MakeValidatedShape(element_type, output_dimensions, - output_is_dynamic); + output_is_dynamic, + output_expressions); } // Encapsulates inferred dimension size and bound size. @@ -472,6 +475,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t last_dim = operand_shape.dimensions_size() - 1; std::vector is_dynamic(operand_shape.dimensions_size()); std::vector dimensions(operand_shape.dimensions_size()); + std::vector expressions(operand_shape.dimensions_size()); TF_RET_CHECK(operand_shape.dimensions(last_dim) >= k) << "k=" << k << " is larger than the last dimension of size=" @@ -480,10 +484,13 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, is_dynamic[i] = i == last_dim ? false : operand_shape.is_dynamic_dimension(i); dimensions[i] = i == last_dim ? k : operand_shape.dimensions(i); + expressions[i] = + i == last_dim ? xla::DynExpr::_(k) : operand_shape.expressions(i); } - Shape out = ShapeUtil::MakeShape(operand_shape.element_type(), dimensions, - is_dynamic); + Shape out = + ShapeUtil::MakeShape(operand_shape.element_type(), dimensions, is_dynamic, + expressions); Shape idxs_shape = ShapeUtil::ChangeElementType(out, PrimitiveType::S32); return ShapeUtil::MakeTupleShape({out, idxs_shape}); } @@ -542,9 +549,11 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t rank = arg_shape->dimensions_size(); std::vector inferred_sizes(rank, Shape::kUnboundedSize); std::vector inferred_bounds(rank, Shape::kUnboundedSize); + std::vector inferred_expressions(rank, DynExpr::zero); // Note: for the concatenate dimension, 0 should be the identity element: // Any dim size can keep unchanged when concatenated with 0 inferred_sizes[dimension] = 0; + inferred_expressions[dimension] = DynExpr::zero; for (const Shape* shape : arg_shapes) { for (int dim = 0; dim < rank; ++dim) { @@ -554,24 +563,32 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t leftSize = inferred_sizes[dim]; int64_t rightSize = dimension_size; int64_t leftBound = inferred_bounds[dim]; + xla::DynExpr* leftExpression = inferred_expressions[dim]; int64_t rightBound = shape->is_dynamic_dimension(dim) ? dimension_size : Shape::kUnboundedSize; + xla::DynExpr* rightExpression = shape->expressions(dim); + xla::DynExpr* inferred_expression = xla::DynExpr::zero; + if (dim == dimension) { inferred_dim_and_bound = InferConcatenatedDimAndBound( leftSize, rightSize, leftBound, rightBound); + inferred_expression = *leftExpression + *rightExpression; } else { TF_ASSIGN_OR_RETURN( inferred_dim_and_bound, InferMostSpecificDimAndBound(dim, leftSize, rightSize, leftBound, rightBound)); + inferred_expression = rightExpression; } inferred_sizes[dim] = inferred_dim_and_bound.dimension; inferred_bounds[dim] = inferred_dim_and_bound.bound; + inferred_expressions[dim] = inferred_expression->s(); } } - Shape result = ShapeUtil::MakeShape(element_type, inferred_sizes); + Shape result = + ShapeUtil::MakeShape(element_type, inferred_sizes, inferred_expressions); for (int64_t i = 0; i < inferred_bounds.size(); ++i) { if (!IsUnboundedDynamicSize(inferred_bounds[i]) || IsUnboundedDynamicSize(inferred_sizes[i])) { @@ -756,6 +773,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, std::vector dimensions(operand_shape.dimensions_size()); std::vector is_dynamic(operand_shape.dimensions_size()); + std::vector expressions(operand_shape.dimensions_size()); for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); if (operand_shape.is_unbounded_dynamic_dimension(i)) { @@ -771,11 +789,13 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, } } is_dynamic[i] = operand_shape.is_dynamic_dimension(i); + auto diff = dimensions[i] - operand_shape.dimensions(i); + expressions[i] = (*operand_shape.expressions(i) + diff)->s(); } return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), - dimensions, is_dynamic); + dimensions, is_dynamic, expressions); } // Current DotDimensionNumbers Requirements: @@ -920,7 +940,9 @@ absl::Status CheckDotDimensionConstraints( void GenerateDotResultDimensions( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, - std::vector& dimensions, std::vector& is_dynamic, + std::vector& dimensions, + std::vector& expressions, + std::vector& is_dynamic, std::vector rhs_group_dimensions = {}) { const auto& lhs_batch_dimensions = dimension_numbers.lhs_batch_dimensions(); const auto lhs_batch_dimensions_size = @@ -930,9 +952,11 @@ void GenerateDotResultDimensions( dimension_numbers.rhs_contracting_dimensions().size() - dimension_numbers.rhs_batch_dimensions().size(); dimensions.reserve(lhs_batch_dimensions_size); + expressions.reserve(lhs_batch_dimensions_size); is_dynamic.reserve(lhs_batch_dimensions_size); for (const int64_t lhs_dim : lhs_batch_dimensions) { dimensions.push_back(lhs.dimensions(lhs_dim)); + expressions.push_back(lhs.expressions(lhs_dim)); is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim)); } for (int64_t i = 0; i < lhs.dimensions_size(); i++) { @@ -940,6 +964,7 @@ void GenerateDotResultDimensions( i) && !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); + expressions.push_back(lhs.expressions(i)); is_dynamic.push_back(lhs.is_dynamic_dimension(i)); } } @@ -949,6 +974,7 @@ void GenerateDotResultDimensions( !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i) && !absl::c_linear_search(rhs_group_dimensions, i)) { dimensions.push_back(rhs.dimensions(i)); + expressions.push_back(rhs.expressions(i)); is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } @@ -990,12 +1016,14 @@ void GenerateDotResultDimensions( std::vector dimensions; std::vector is_dynamic; + std::vector expressions; GenerateDotResultDimensions(lhs, rhs, dimension_numbers, dimensions, - is_dynamic); + expressions, is_dynamic); PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); - Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic); + Shape result = + ShapeUtil::MakeShape(type, dimensions, is_dynamic, expressions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -1193,6 +1221,7 @@ void GenerateDotResultDimensions( PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); std::vector dimensions; + std::vector expressions; std::vector is_dynamic; // Add the group dimension to the result shape in case of ragged contracting. if (mode == kContracting) { @@ -1200,9 +1229,10 @@ void GenerateDotResultDimensions( is_dynamic.push_back(is_dynamic_group_sizes); } GenerateDotResultDimensions(lhs, rhs, dimension_numbers, dimensions, - is_dynamic, rhs_group_dimensions); + expressions, is_dynamic, rhs_group_dimensions); - Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic); + Shape result = + ShapeUtil::MakeShape(type, dimensions, is_dynamic, expressions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred ragged dot shape: " << ShapeUtil::HumanString(result); return result; @@ -1248,12 +1278,15 @@ void GenerateDotResultDimensions( // Build the resulting shape dimensions. std::vector dimensions; std::vector is_dynamic; + std::vector expressions; for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { dimensions.push_back(i != sparsity.dimension() ? operand_shape.dimensions(i) : metadata_dimension_size); is_dynamic.push_back(operand_shape.is_dynamic_dimension(i)); + expressions.push_back(operand_shape.expressions(i)); } - return ShapeUtil::MakeShape(element_type, dimensions, is_dynamic); + return ShapeUtil::MakeShape(element_type, dimensions, is_dynamic, + expressions); } /* static */ absl::StatusOr @@ -1267,6 +1300,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, // from the lhs/rhs pair in every index. std::vector output_dimensions(lhs.dimensions_size()); std::vector output_dimensions_is_dynamic(lhs.dimensions_size()); + std::vector output_dimensions_expressions(lhs.dimensions_size()); for (int64_t i = 0; i < lhs.dimensions_size(); ++i) { if (lhs.dimensions(i) == 1 || rhs.dimensions(i) == 1) { // For the unbounded case, the operand with 1 should be broadcasted to the @@ -1283,7 +1317,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions_is_dynamic[i] = lhs.dimensions(i) == 1 ? rhs.is_dynamic_dimension(i) : lhs.is_dynamic_dimension(i); - } else if (lhs.dimensions(i) == rhs.dimensions(i)) { + output_dimensions_expressions[i] = lhs.dimensions(i) == 1 + ? rhs.expressions(i) + : lhs.expressions(i); + } else if (lhs.dimensions(i) == rhs.dimensions(i)) { // && + // *lhs.expressions(i) == *rhs.expressions(i)) { // LHS | RHS | Result // X | X | X // X | <=X | <=X @@ -1293,6 +1331,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions[i] = lhs.dimensions(i); output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i) || rhs.is_dynamic_dimension(i); + output_dimensions_expressions[i] = lhs.expressions(i); } else if (lhs.is_unbounded_dynamic_dimension(i) || rhs.is_unbounded_dynamic_dimension(i)) { // For the last two rows, consider when <=X turns out to be 1 and ? turns @@ -1309,6 +1348,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions_is_dynamic[i] = lhs.is_unbounded_dynamic_dimension(i) ? rhs.is_dynamic_dimension(i) : lhs.is_dynamic_dimension(i); + output_dimensions_expressions[i] = lhs.is_unbounded_dynamic_dimension(i) + ? rhs.expressions(i) + : lhs.expressions(i); } else { return InvalidArgument("Binary op with incompatible shapes: %s and %s.", ShapeUtil::HumanString(lhs), @@ -1317,7 +1359,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - output_dimensions, output_dimensions_is_dynamic); + output_dimensions, output_dimensions_is_dynamic, + output_dimensions_expressions); } /* static */ absl::StatusOr ShapeInference::InferInDimBroadcastShape( @@ -1398,6 +1441,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, dimension_to_match, larger_shape.dimensions_size())); } int64_t small_dimension_size = smaller_shape.dimensions(i); + DynExpr* small_dimension_exp = smaller_shape.expressions(i); int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match); bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); bool large_is_dynamic = @@ -1436,6 +1480,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_shape.set_dimensions(dimension_to_match, small_dimension_size, small_is_dynamic); + output_shape.set_expression(dimension_to_match, small_dimension_exp); } return output_shape; @@ -1779,7 +1824,9 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { output_shape.element_type(), arg_shape->dimensions(), /*dynamic_dimensions=*/ std::vector(arg_shape->dynamic_dimensions().begin(), - arg_shape->dynamic_dimensions().end())); + arg_shape->dynamic_dimensions().end()), + /*expressions=*/ + arg_shape->expressions()); } /* static */ absl::StatusOr ShapeInference::InferBatchNormTrainingShape( @@ -1864,8 +1911,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const int64_t feature_count = operand_shape.dimensions(feature_index); bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); - Shape output_shape_for_mean_and_var = ShapeUtil::MakeShape( - operand_shape.element_type(), {feature_count}, {dynamic_feature}); + DynExpr* expression_feature = + operand_shape.expressions(feature_index); + + Shape output_shape_for_mean_and_var = + ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}, + {dynamic_feature}, {expression_feature}); if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(offset_shape, 0), feature_count)) { @@ -2148,8 +2199,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const int64_t feature_count = operand_shape.dimensions(feature_index); bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); + DynExpr* expression_feature = + operand_shape.expressions(feature_index); + Shape feature_shape = ShapeUtil::MakeShape( - operand_shape.element_type(), {feature_count}, {dynamic_feature}); + operand_shape.element_type(), {feature_count}, {dynamic_feature}, + {expression_feature}); if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(mean_shape, 0), feature_count)) { @@ -2398,13 +2453,17 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } std::vector dynamic_dimensions(input_spatial_dims.size()); + std::vector expressions(input_spatial_dims.size()); for (auto it = input_spatial_dims.begin(); it != input_spatial_dims.end(); ++it) { dynamic_dimensions[it - input_spatial_dims.begin()] = IsUnboundedDynamicSize(*it); + expressions[it - input_spatial_dims.begin()] = + DynExpr::_(-70); } Shape base_shape = ShapeUtil::MakeShape( - lhs.element_type(), input_spatial_dims, dynamic_dimensions); + lhs.element_type(), input_spatial_dims, dynamic_dimensions, + expressions); TF_ASSIGN_OR_RETURN( Shape window_output_shape, InferWindowOutputShape(base_shape, window, lhs.element_type())); @@ -2457,7 +2516,8 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); - return ShapeUtil::MakeShape(type, dimensions, is_dynamic); + return ShapeUtil::MakeShape(type, dimensions, is_dynamic, + expressions); } /* static */ absl::StatusOr ShapeInference::InferFftShape( @@ -2780,8 +2840,10 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const std::vector dynamic_dimensions(shape.dynamic_dimensions().begin(), shape.dynamic_dimensions().end()); + auto exprs = shape.expressions(); + std::vector expressions(exprs.begin(), exprs.end()); return ShapeUtil::MakeShape(shape.element_type(), new_dimensions, - dynamic_dimensions); + dynamic_dimensions, expressions); } /* static */ absl::StatusOr ShapeInference::InferAllToAllTupleShape( @@ -2923,23 +2985,28 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { std::vector new_dimensions; std::vector new_is_dynamic; + std::vector new_expressions; for (int i = 0; i < arg.dimensions_size(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); new_is_dynamic.push_back(arg.is_dynamic_dimension(i)); + new_expressions.push_back(arg.expressions(i)); } } if (ShapeUtil::IsScalar(to_apply.result())) { - return ShapeUtil::MakeShape(to_apply.result().element_type(), - new_dimensions, new_is_dynamic); + return ShapeUtil::MakeShape( + to_apply.result().element_type(), new_dimensions, new_is_dynamic, + new_expressions); } else { std::vector result_subshapes; const auto& tuple_shapes = to_apply.result().tuple_shapes(); result_subshapes.reserve(tuple_shapes.size()); for (const Shape& subshape : tuple_shapes) { - result_subshapes.push_back(ShapeUtil::MakeShape( - subshape.element_type(), new_dimensions, new_is_dynamic)); + auto new_shape = ShapeUtil::MakeShape( + subshape.element_type(), new_dimensions, new_is_dynamic, + new_expressions); + result_subshapes.push_back(new_shape); } return ShapeUtil::MakeTupleShape(result_subshapes); } @@ -3172,7 +3239,9 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, - absl::Span limits, absl::Span strides) { + absl::Span limits, absl::Span strides, + absl::Span start_exprs, + absl::Span limit_exprs) { auto error = [&](const std::string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " @@ -3202,8 +3271,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } std::vector sizes; + std::vector expressions; const auto starts_size = starts.size(); sizes.reserve(starts_size); + expressions.reserve(starts_size); for (int64_t dimension = 0; dimension < starts_size; ++dimension) { int64_t start_index = starts[dimension]; int64_t limit_index = limits[dimension]; @@ -3231,6 +3302,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); + + auto limit_expr = + limit_exprs.empty() ? DynExpr::_(limit_index) : limit_exprs[dimension]; + auto start_expr = + start_exprs.empty() ? DynExpr::_(start_index) : start_exprs[dimension]; + + auto new_expr = (*(*(*limit_expr - *start_expr) + stride) - 1)->s(); + expressions.push_back((*new_expr/stride)->s()); } std::vector is_dynamic(arg.dimensions_size()); @@ -3242,12 +3321,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { is_dynamic[i] = arg.is_bounded_dynamic_dimension(i); } - return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic); + return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic, + expressions); } /* static */ absl::StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, - absl::Span slice_sizes, bool allow_scalar_indices) { + absl::Span slice_sizes, + absl::Span slice_exprs, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); auto number_of_indices = start_index_shapes.size(); // TODO(b/118437727): Remove this path. @@ -3346,8 +3427,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } - Shape result = - ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); + Shape result = ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes, + slice_exprs); for (int64_t dimension = 0; dimension < operand_shape.dimensions_size(); ++dimension) { @@ -3646,7 +3727,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } /* static */ absl::StatusOr ShapeInference::InferBroadcastShape( - const Shape& operand, absl::Span broadcast_sizes) { + const Shape& operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs) { // This method is used to infer shape for xla::BroadcastInDim. TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); TF_RET_CHECK(!operand.is_unbounded_dynamic()); @@ -3666,12 +3748,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { std::copy(operand.dimensions().begin(), operand.dimensions().end(), dimensions.begin() + broadcast_sizes.size()); - TF_ASSIGN_OR_RETURN(Shape result, ShapeUtil::MakeValidatedShape( - operand.element_type(), dimensions)); + TF_ASSIGN_OR_RETURN( + Shape result, ShapeUtil::MakeValidatedShape(operand.element_type(), + dimensions, broadcast_exprs)); for (int64_t i = 0; i < operand.dimensions_size(); ++i) { result.set_dynamic_dimension(broadcast_sizes.size() + i, operand.is_dynamic_dimension(i)); } + return result; } @@ -3735,7 +3819,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic) { + const std::vector& dims_are_dynamic, + absl::Span expressions) { if (new_size_bounds.size() != dims_are_dynamic.size()) { return InvalidArgument( "DynamicReshape has to have the same number of elements in new_sizes " @@ -3751,9 +3836,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { dim_size_shape->ToString()); } } - Shape inferred_shape = ShapeUtil::MakeShape( - operand.element_type(), new_size_bounds, dims_are_dynamic); + operand.element_type(), new_size_bounds, dims_are_dynamic, expressions); if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( "Reshape operation has mismatched element counts: from=%d (%s) " @@ -3767,10 +3851,21 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, - int64_t inferred_dimension) { + int64_t inferred_dimension, absl::Span expressions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = - ShapeUtil::MakeShape(operand.element_type(), dimensions); + ShapeUtil::MakeShape(operand.element_type(), dimensions, expressions); + + if (expressions.empty() && operand.expressions().size() > 0 && + operand.expressions(0)->is_dynamic()) { + return InvalidArgument("Expressions is empty but operand is dynamic"); + } + + // if (!expressions.empty() && expressions[0]->is_constant() && + // expressions[0]->get_val() == 977) { + // return InvalidArgument("Expressions[0] is the magic number (977)."); + // } + VLOG(3) << "Reshape inferred shape: " << ShapeUtil::HumanString(inferred_shape); @@ -4250,18 +4345,25 @@ static absl::Status ValidateGatherDimensionNumbers( std::vector expanded_start_indices_shape; // Also tracks if an output dimension is dynamic. std::vector expanded_start_indices_shape_dynamic_dimensions; + std::vector expanded_start_indices_shape_expressions; expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); expanded_start_indices_shape_dynamic_dimensions.reserve( start_indices_shape.dimensions_size()); + expanded_start_indices_shape_expressions.reserve( + start_indices_shape.dimensions_size()); absl::c_copy(start_indices_shape.dimensions(), std::back_inserter(expanded_start_indices_shape)); absl::c_copy( start_indices_shape.dynamic_dimensions(), std::back_inserter(expanded_start_indices_shape_dynamic_dimensions)); + absl::c_copy( + start_indices_shape.expressions(), + std::back_inserter(expanded_start_indices_shape_expressions)); if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { expanded_start_indices_shape.push_back(1); expanded_start_indices_shape_dynamic_dimensions.push_back(false); + expanded_start_indices_shape_expressions.push_back(DynExpr::one); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( @@ -4328,10 +4430,12 @@ static absl::Status ValidateGatherDimensionNumbers( output_dim_bounds.reserve(result_rank); std::vector output_dim_is_dynamic; + std::vector output_expressions; output_dim_is_dynamic.reserve(result_rank); for (int64_t i = 0; i < result_rank; i++) { int64_t current_bound; bool dim_dynamic = false; + DynExpr* expression = DynExpr::_(-80); bool is_window_index = absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { @@ -4353,6 +4457,9 @@ static absl::Status ValidateGatherDimensionNumbers( if (slice_sizes[offset_dims_seen] == input_shape.dimensions(offset_dims_seen)) { dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen); + expression = input_shape.expressions(offset_dims_seen); + } else { + expression = DynExpr::_(slice_sizes[offset_dims_seen]); } current_bound = slice_sizes[offset_dims_seen++]; } else { @@ -4362,15 +4469,17 @@ static absl::Status ValidateGatherDimensionNumbers( // Forward dynamic dimensions from indices. dim_dynamic = expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen]; - + expression = expanded_start_indices_shape_expressions[gather_dims_seen]; current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_is_dynamic.push_back(dim_dynamic); + output_expressions.push_back(expression); output_dim_bounds.push_back(current_bound); } - return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds, - output_dim_is_dynamic); + auto s = ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds, + output_dim_is_dynamic, output_expressions); + return s; } namespace { diff --git a/third_party/xla/xla/service/shape_inference.h b/third_party/xla/xla/service/shape_inference.h index 15818f01cdfa7e..4f5d0f4b2e3efa 100644 --- a/third_party/xla/xla/service/shape_inference.h +++ b/third_party/xla/xla/service/shape_inference.h @@ -245,13 +245,17 @@ class ShapeInference { // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] static absl::StatusOr InferSliceShape( const Shape& arg, absl::Span starts, - absl::Span limits, absl::Span strides); + absl::Span limits, absl::Span strides, + absl::Span start_exprs = {}, + absl::Span limit_exprs = {}); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static absl::StatusOr InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, - absl::Span slice_sizes, bool allow_scalar_indices = true); + absl::Span slice_sizes, + absl::Span slice_exprs = {}, + bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -283,7 +287,8 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static absl::StatusOr InferBroadcastShape( - const Shape& operand, absl::Span broadcast_sizes); + const Shape& operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs = {}); // Checks whether the given parameters can form a broadcast. Returns the same // output_shape if it's legal. @@ -295,7 +300,7 @@ class ShapeInference { // its operand and the new dimension sizes specified. static absl::StatusOr InferReshapeShape( const Shape& operand, absl::Span dimensions, - int64_t inferred_dimension); + int64_t inferred_dimension, absl::Span expressions = {}); // Infers the shape produced by a dynamic reshape operation from the element // type of its operand and the new dimension sizes specified. The result shape @@ -304,7 +309,8 @@ class ShapeInference { static absl::StatusOr InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. diff --git a/third_party/xla/xla/service/triangular_solve_expander.cc b/third_party/xla/xla/service/triangular_solve_expander.cc index 67dc0235810de8..ea587461b1e277 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.cc +++ b/third_party/xla/xla/service/triangular_solve_expander.cc @@ -120,7 +120,12 @@ XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { auto last_blocks_dims = std::vector(ndims); std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); - last_blocks = Reshape(last_blocks, last_blocks_dims); + auto shape_exprs = blocks_shape.expressions(); + auto last_blocks_exprs = std::vector(ndims); + std::copy(shape_exprs.begin(), shape_exprs.end(), + last_blocks_exprs.begin()); + last_blocks_exprs.insert(last_blocks_exprs.end() - 2, DynExpr::one); + last_blocks = Reshape(last_blocks, last_blocks_dims, last_blocks_exprs); // Concatenate with the other blocks if necessary if (n > block_size) { @@ -366,7 +371,7 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( /*broadcast_dimensions=*/{0, 1}); // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, shape.dimensions()); + return Reshape(inv_diag_blocks, shape.dimensions(), shape.expressions()); }); } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index dbfe4d185dc904..e949ebf4c80fc8 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -41,6 +41,295 @@ limitations under the License. namespace xla { +DynExpr* ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DynExpr::_(proto.constant_value()); + + case ExpressionProto::kVariableId: + return DynExpr::V(proto.variable_id()); + + case ExpressionProto::kAddNode: { + const auto& add = proto.add_node(); + return *ExprFromProto(add.lhs()) + *ExprFromProto(add.rhs()); + } + + case ExpressionProto::kSubNode: { + const auto& sub = proto.sub_node(); + return *ExprFromProto(sub.lhs()) - *ExprFromProto(sub.rhs()); + } + + case ExpressionProto::kMulNode: { + const auto& mul = proto.mul_node(); + return *ExprFromProto(mul.lhs()) * *ExprFromProto(mul.rhs()); + } + + case ExpressionProto::kDivNode: { + const auto& div = proto.div_node(); + return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); + } + + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +DynExpr* operator*(DynExpr& lhs, DynExpr& rhs) { return new Mul(&lhs, &rhs); } +DynExpr* operator*(int64_t k, DynExpr& rhs) { + return new Mul(DynExpr::_(k), &rhs); +} +DynExpr* operator/(DynExpr& lhs, DynExpr& rhs) { return new Div(&lhs, &rhs); } +DynExpr* operator/(DynExpr& lhs, int64_t d) { + return new Div(&lhs, DynExpr::_(d)); +} +DynExpr* operator+(DynExpr& lhs, DynExpr& rhs) { return new Add(&lhs, &rhs); } +DynExpr* operator+(DynExpr& lhs, int64_t d) { + return new Add(&lhs, DynExpr::_(d)); +} +DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { return new Sub(&lhs, &rhs); } +DynExpr* operator-(DynExpr& lhs, int64_t d) { + return new Sub(&lhs, DynExpr::_(d)); +} +bool operator==(DynExpr& lhs, DynExpr& rhs) { + return DynExpr::equal(&lhs, &rhs); +} +bool operator==(DynExpr& lhs, int64_t d) { + return DynExpr::equal(&lhs, DynExpr::_(d)); +} +bool operator<(DynExpr& lhs, int64_t d) { + return lhs.is_constant() && lhs.get_val() < d; +} + +bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { + auto e1 = expr1->s(); + auto e2 = expr2->s(); + if (e1 == nullptr || e2 == nullptr) return false; + Constant* c1 = dynamic_cast(e1); + Constant* c2 = dynamic_cast(e2); + if (c1 && c2) return c1->get_val() == c2->get_val(); + // Var x = Var y <=> x = y + if (Variable* varx = dynamic_cast(e1), + *vary = dynamic_cast(e2); + varx && vary) { + return varx->get_id() == vary->get_id(); + } + // a * b = c * d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (Mul* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto a = ab->get_lhs(); + auto b = ab->get_rhs(); + auto c = cd->get_lhs(); + auto d = cd->get_rhs(); + return (*a == *c && *b == *d) || (*a == *d && *b == *c); + } + // a / b = c / d <=> (a = c /\ b = d) + if (Div* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto a = ab->get_lhs(); + auto b = ab->get_rhs(); + auto c = cd->get_lhs(); + auto d = cd->get_rhs(); + return *a == *c && *b == *d; + } + // a + b = c + d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (Add* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto a = ab->get_lhs(); + auto b = ab->get_rhs(); + auto c = cd->get_lhs(); + auto d = cd->get_rhs(); + return (*a == *c && *b == *d) || (*a == *d && *b == *c); + } + // a - b = c - d <=> (a = c /\ b = d) + if (Sub* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto* a = ab->get_lhs(); + auto* b = ab->get_rhs(); + auto* c = cd->get_lhs(); + auto* d = cd->get_rhs(); + return *a == *c && *b == *d; + } + return false; +} + +// Simplification methods +DynExpr* Constant::s() { return this; } + +DynExpr* Variable::s() { return this; } + +DynExpr* Mul::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant * constant + if (l && r) return DynExpr::_(l->get_val() * r->get_val()); + // 0 * X = 0 + if (l && l->get_val() == 0) return DynExpr::zero; + // 1 * X = X + if (l && l->get_val() == 1) return s_rhs; + // X * 1 = X + if (r && r->get_val() == 1) return s_lhs; + // X * constant = constant * X + if (r && s_lhs->is_dynamic()) return (r->get_val() * *s_lhs)->s(); + // m * (nX) = (m*n) * X + if (Mul* nX = dynamic_cast(s_rhs)) { + DynExpr* X = nX->get_rhs(); + Constant* n = dynamic_cast(nX->get_lhs()); + if (l && n) { + auto mn = l->get_val() * n->get_val(); + return (mn * *X)->s(); + } + } + return (*s_lhs) * (*s_rhs); +} + +DynExpr* Add::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant + constant + if (l && r) return DynExpr::_(l->get_val() + r->get_val()); + // 0 + X = X + if (l && l->get_val() == 0) return s_rhs; + // X + 0 = X + if (r && r->get_val() == 0) return s_lhs; + // m + X = X + m + if (l && s_rhs->is_dynamic()) return (*s_rhs + l->get_val())->s(); + // X + X = 2 * X + if (*s_lhs == *s_rhs) { + return (2 * (*s_rhs))->s(); + } + // nX + X = (n+1) * X + if (Mul* nX = dynamic_cast(s_lhs)) { + DynExpr* n = nX->get_lhs(); + DynExpr* X = nX->get_rhs(); + if (*X == *s_rhs) { + return (*(*n + 1) * (*X))->s(); + } + } + // X + nX = (n+1) * X + if (Mul* nX = dynamic_cast(s_rhs)) { + DynExpr* n = nX->get_lhs(); + DynExpr* X = nX->get_rhs(); + if (*X == *s_lhs) { + return (*(*n + 1) * (*X))->s(); + } + } + // mX + nX = (m+n) * X + if (Mul* mX = dynamic_cast(s_lhs), *nY = dynamic_cast(s_rhs); + mX && nY) { + DynExpr* m = mX->get_lhs(); + DynExpr* X = mX->get_rhs(); + DynExpr* n = nY->get_lhs(); + DynExpr* Y = nY->get_rhs(); + if (*X == *Y) { + return (*(*m + *n) * (*X))->s(); + } + } + // (X + Y) + Z = X + (Y + Z) + if (Add* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X + *(*Y + *s_rhs))->s(); + } + // (X - Y) + Z = X - (Y - Z) + if (Sub* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X - *(*Y - *s_rhs))->s(); + } + return *s_lhs + *s_rhs; +} + +DynExpr* Sub::s() { + if (!get_lhs()){ + LOG(INFO) << "NO LEFT"; + } + + if (!get_rhs()){ + LOG(INFO) << "NO RIGHT"; + } + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant - constant + if (l && r) return DynExpr::_(l->get_val() - r->get_val()); + // X - 0 = X + if (r && r->get_val() == 0) return s_lhs; + // X - X = 0 + if (*s_lhs == *s_rhs) { + return DynExpr::zero; + } + // mX - nX = (m-n) * X + if (Mul* mX = dynamic_cast(s_lhs), *nY = dynamic_cast(s_rhs); + mX && nY) { + DynExpr* m = mX->get_lhs(); + DynExpr* X = mX->get_rhs(); + DynExpr* n = nY->get_lhs(); + DynExpr* Y = nY->get_rhs(); + if (*X == *Y) { + return (*(*m - *n) * (*X))->s(); + } + } + // (X + Y) - X = X + (Y - Z) + if (Add* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X + *(*Y - *s_rhs))->s(); + } + // (X - Y) - Z = X - (Y + Z) + if (Sub* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X - *(*Y + *s_rhs))->s(); + } + return *s_lhs - *s_rhs; +} + +DynExpr* Div::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant / constant + if (l && r) return DynExpr::_(l->get_val() / r->get_val()); + // X / 1 = X + if (r && r->get_val() == 1) return s_lhs; + // (X + Y) / Z = (X/Z) + (Y/Z) + if (Add* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*((*X) / (*s_rhs)) + *((*Y) / (*s_rhs)))->s(); + } + // (X * Y) / Z = (X/Z) * Y + if (Mul* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*(*X / (*s_rhs)) * (*Y))->s(); + } + // (X / Y) / Z = X / (Y*Z) + if (Div* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X / *(*Y * *s_rhs))->s(); + } + return *s_lhs / *s_rhs; +} + +std::ostream& operator<<(std::ostream& os, DynExpr* expr) { + ExpressionProto proto; + expr->to_proto(&proto); + os << proto.ShortDebugString(); + return os; +} + +DynExpr* DynExpr::zero = new Constant(0); +DynExpr* DynExpr::one = new Constant(1); + // Defined in .cc file to avoid inlining these large routines Shape::Shape() = default; Shape::~Shape() = default; @@ -97,13 +386,18 @@ Shape::Shape(const ShapeProto& shape_proto) { } absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { + + // LOG(INFO) << "FROM PROTO:\n" << shape_proto.DebugString() << std::endl; + Shape shape; shape.set_element_type(shape_proto.element_type()); if (auto* const state = shape.if_array_state()) { const int num_dims = shape_proto.dimensions_size(); const int num_is_dynamic_dims = shape_proto.is_dynamic_dimension_size(); + const int num_expressions = shape_proto.expressions_size(); state->dimensions.reserve(num_dims); state->dynamic_dimensions.reserve(num_dims); + state->expressions.reserve(num_dims); if (num_is_dynamic_dims != 0) { TF_RET_CHECK(num_dims == num_is_dynamic_dims) << "Malformed shape proto: number of is_dynamic_dimension " @@ -111,6 +405,13 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { << num_is_dynamic_dims << ") does not match number of dimension " << "fields (" << num_dims << ")."; } + if (num_expressions != 0) { + TF_RET_CHECK(num_dims == num_expressions) + << "Malformed shape proto: number of expressions " + "fields (" + << num_expressions << ") does not match number of dimension " + << "fields (" << num_dims << ")."; + } for (int i = 0; i < num_dims; ++i) { const bool is_dynamic = (i < num_is_dynamic_dims) && shape_proto.is_dynamic_dimension(i); @@ -118,7 +419,11 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { // UnsafeAddDimension. We expect that the caller will eventually call a // validation routine that will detect the error in case the dimension // value is invalid. - shape.UnsafeAddDimension(shape_proto.dimensions(i), is_dynamic); + DynExpr* expression = (i < num_expressions) + ? ExprFromProto(shape_proto.expressions(i)) + : DynExpr::_(-30); + shape.UnsafeAddDimension(shape_proto.dimensions(i), is_dynamic, + expression); } } else if (auto* const state = shape.if_tuple_state()) { state->tuple_shapes.reserve(shape_proto.tuple_shapes_size()); @@ -136,7 +441,7 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { TF_ASSIGN_OR_RETURN(*shape.mutable_layout(), Layout::FromProto(shape_proto.layout())); } - shape.set_outer_multiplier(shape_proto.batch_multiplier()); + // LOG(INFO) << "FROM PROTO " << shape << "\n"; return shape; } @@ -144,6 +449,8 @@ ShapeProto Shape::ToProto() const { ShapeProto proto; proto.set_element_type(element_type_); + // LOG(INFO) << "TO PROTO " << ToString() << "\n"; + if (const auto* const state = if_array_state()) { proto.mutable_dimensions()->Reserve(state->dimensions.size()); for (const int64_t dimension : state->dimensions) { @@ -152,6 +459,11 @@ ShapeProto Shape::ToProto() const { for (const bool dynamic : state->dynamic_dimensions) { proto.add_is_dynamic_dimension(dynamic); } + for (const DynExpr* e : state->expressions) { + ExpressionProto* eproto = proto.add_expressions(); + CHECK(e != nullptr) << "Missing expression in expression list."; + e->to_proto(eproto); + } if (state->layout.has_value()) { *proto.mutable_layout() = state->layout->ToProto(); } @@ -164,7 +476,7 @@ ShapeProto Shape::ToProto() const { proto.mutable_tuple_shapes()->Reserve(1); *proto.add_tuple_shapes() = state->buffer_shape[0].ToProto(); } - proto.set_batch_multiplier(outer_multiplier()); + // LOG(INFO) << "DEBUG VIEW:\n" << proto.DebugString() << std::endl; return proto; } @@ -247,14 +559,14 @@ bool Shape::AreAllLeavesIntegers() const { return primitive_util::IsIntegralType(element_type()); } -void Shape::add_dimensions(int64_t value, bool is_dynamic) { +void Shape::add_dimensions(int64_t value, bool is_dynamic, DynExpr* expr) { if (value < 0) { CHECK(is_dynamic) << "static dimension must have size >= 0 instead of " << value << "."; CHECK_EQ(value, kUnboundedSize) << "dynamic dimension must have size == kUnboundedSize or >= 0."; } - UnsafeAddDimension(value, is_dynamic); + UnsafeAddDimension(value, is_dynamic, expr); } void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { @@ -264,6 +576,19 @@ void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { state.dynamic_dimensions[dimension] = is_dynamic; } +void Shape::set_expression(int dimension, DynExpr* e) { + auto& state = array_state(); + state.expressions[dimension] = e; +} + +void Shape::set_expressions(std::vector exps) { + auto& state = array_state(); + state.expressions.clear(); + for (auto e : exps){ + state.expressions.push_back(e); + } +} + void Shape::set_dimensions(int index, int64_t size, std::optional is_dynamic) { auto& state = array_state(); @@ -272,6 +597,7 @@ void Shape::set_dimensions(int index, int64_t size, CheckDimensionSize(index, size, dynamic); state.dimensions[index] = size; state.dynamic_dimensions[index] = dynamic; + state.expressions[index] = DynExpr::_(size); } void Shape::set_dimensions_minor(int index, int64_t size, @@ -293,12 +619,15 @@ void Shape::CheckDimensionSize(int dim_index, int64_t size, bool is_dynamic) { } } -void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic) { +void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic, DynExpr* exp) { auto& state = array_state(); CHECK_EQ(state.dimensions.size(), state.dynamic_dimensions.size()) << "where the shape is " << ToString(); + CHECK_EQ(state.dimensions.size(), state.expressions.size()) + << "where the shape is " << ToString(); state.dimensions.push_back(value); state.dynamic_dimensions.push_back(is_dynamic); + state.expressions.push_back(exp); } bool Shape::is_static() const { @@ -347,6 +676,8 @@ void Shape::DeleteDimension(int64_t dim_to_delete) { state.dimensions.erase(state.dimensions.begin() + dim_to_delete); state.dynamic_dimensions.erase(state.dynamic_dimensions.begin() + dim_to_delete); + state.expressions.erase(state.expressions.begin() + + dim_to_delete); if (LayoutUtil::HasLayout(*this)) { state.layout->DeleteDimension(dim_to_delete); // NOLINT: optional-access } @@ -360,6 +691,8 @@ void Shape::DeleteDimensions(absl::Span dims_to_delete) { state.dimensions = RemoveElements(sorted_dims_to_delete, state.dimensions); state.dynamic_dimensions = RemoveElements(sorted_dims_to_delete, state.dynamic_dimensions); + state.expressions = + RemoveElements(sorted_dims_to_delete, state.expressions); if (LayoutUtil::HasLayout(*this)) { for (auto it = sorted_dims_to_delete.rbegin(); it != sorted_dims_to_delete.rend(); ++it) { @@ -372,6 +705,7 @@ void Shape::CheckStateIsEmpty() const { if (const auto* const state = if_array_state()) { CHECK(state->dimensions.empty()) << ToString(); CHECK(state->dynamic_dimensions.empty()) << ToString(); + CHECK(state->expressions.empty()) << ToString(); CHECK(!state->layout.has_value()) << ToString(); } else if (const auto* const state = if_tuple_state()) { CHECK(state->tuple_shapes.empty()) << ToString(); @@ -511,7 +845,8 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { rhs.is_unbounded_dynamic_dimension(i))) { continue; } - if (i == 0 && ignore_batch_ && (lhs.outer_multiplier() > 0 || rhs.outer_multiplier() > 0)) { + if (i == 0 && ignore_batch_ && + (lhs.outer_multiplier() > 0 || rhs.outer_multiplier() > 0)) { VLOG(3) << "CompareShapes: batch dimension found. Forcely compatible"; continue; } diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index a53297d18ec6fc..7c535c2e1e769a 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -41,6 +41,258 @@ limitations under the License. namespace xla { +class DynExpr { + public: + virtual ~DynExpr() = default; + virtual void print(xla::Printer* printer) const = 0; + virtual void to_proto(xla::ExpressionProto* proto) const = 0; + virtual bool is_constant() const = 0; + virtual int64_t get_val() const { return -1; } + virtual DynExpr* s() = 0; // simplify + virtual DynExpr* substitute(int id, DynExpr* v) = 0; + + bool is_dynamic() { return !is_constant(); } + + static DynExpr* zero; + static DynExpr* one; + static DynExpr* _(int64_t val); + static DynExpr* V(int var_id); + static DynExpr* _s(DynExpr* expr); + static bool equal(DynExpr* expr1, DynExpr* expr2); + + friend std::ostream& operator<<(std::ostream& os, DynExpr* expr); +}; + +// constant i +class Constant : public DynExpr { + int64_t value; + + public: + explicit Constant(int64_t v) : value(v) {} + void print(xla::Printer* printer) const override { + if (value < 0) { + printer->Append("("); + } + printer->Append(value); + if (value < 0) { + printer->Append(")"); + } + } + void to_proto(xla::ExpressionProto* proto) const override { + proto->set_constant_value(value); + } + bool is_constant() const override { return true; } + int64_t get_val() const override { return value; } + DynExpr* substitute(int id, DynExpr* v) { return this; } + DynExpr* s() override; +}; + +// var id (int) +class Variable : public DynExpr { + int id; + + public: + explicit Variable(int identifier) : id(identifier) {} + void print(xla::Printer* printer) const override { + // printer->Append("(Var "); + char letter = 'A' + (id - 1); + printer->Append(std::string(1, letter)); + // printer->Append(")"); + } + void to_proto(xla::ExpressionProto* proto) const override { + proto->set_variable_id(id); + } + bool is_constant() const override { return false; } + int get_id() const { return id; } + DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} + DynExpr* s() override; +}; + +// exp = exp + exp +class Add : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Add(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" + "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* add_msg = proto->mutable_add_node(); + lhs->to_proto(add_msg->mutable_lhs()); + rhs->to_proto(add_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() + rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Add(lhs->substitute(id, v), rhs->substitute(id, v)); + } + DynExpr* s() override; + + ~Add() { + delete lhs; + delete rhs; + } +}; + +// exp = exp - exp +class Sub : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Sub(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" - "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* sub_msg = proto->mutable_sub_node(); + lhs->to_proto(sub_msg->mutable_lhs()); + rhs->to_proto(sub_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() - rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Sub(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + DynExpr* s() override; + + ~Sub() { + delete lhs; + delete rhs; + } +}; + +// exp = exp * exp +class Mul : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Mul(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" * "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* mul_msg = proto->mutable_mul_node(); + lhs->to_proto(mul_msg->mutable_lhs()); + rhs->to_proto(mul_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() * rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Mul(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + DynExpr* s() override; + + ~Mul() { + delete lhs; + delete rhs; + } +}; + +// expr / expr +class Div : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Div(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" / ( "); + rhs->print(printer); + printer->Append(") )"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* div_msg = proto->mutable_div_node(); + lhs->to_proto(div_msg->mutable_lhs()); + rhs->to_proto(div_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() / rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Div(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + DynExpr* s() override; + + ~Div() { + delete lhs; + delete rhs; + } +}; + +DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator*(int64_t k, DynExpr& rhs); +DynExpr* operator/(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator/(DynExpr& lhs, int64_t d); +DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator+(DynExpr& lhs, int64_t d); +DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator-(DynExpr& lhs, int64_t d); +bool operator==(DynExpr& lhs, DynExpr& rhs); +bool operator==(DynExpr& lhs, int64_t d); + +inline DynExpr* DynExpr::_(int64_t val) { + if (val == 0) return DynExpr::zero; + if (val == 1) return DynExpr::one; + return new Constant(val); +} +inline DynExpr* DynExpr::V(int var_id) { return new Variable(var_id); } + // A shape describes the number of dimensions in a array, the bounds of each // dimension, and the primitive component type. For tuples, shape describes the // structure (number of elements and nesting). @@ -218,6 +470,23 @@ class Shape { return array_state().dynamic_dimensions[dimension]; } + bool has_dynamic_expr() const { + if (auto* const state = if_array_state()) { + return absl::c_any_of(state->expressions, + [](DynExpr* e) { return e->is_dynamic(); }); + } + if (auto* const state = if_tuple_state()) { + return absl::c_any_of(state->tuple_shapes, [](Shape subshape) { + return subshape.has_dynamic_expr(); + }); + } + return false; + } + + DynExpr* expressions(int dimension) const { + return array_state().expressions[dimension]; + } + // Returns true if the given dimension is statically-sized. // Precondition: this is an array shape and `dimension` is a valid dimension // index. @@ -232,12 +501,20 @@ class Shape { // - The dimension's size is valid for the given dynamic-ness. void set_dynamic_dimension(int dimension, bool is_dynamic); + void set_expression(int dimension, DynExpr* e); + + void set_expressions(std::vector exprs); + // Returns a span to indicate whether each dimension is dynamic. // Precondition: this is an array shape. absl::Span dynamic_dimensions() const { return array_state().dynamic_dimensions; } + absl::Span expressions() const { + return array_state().expressions; + } + // Removes the given dimension from the shape. Layout, if it exists, is // adjusted to match the modified shape. // Precondition: this is an array shape, and the input dimension indices are @@ -313,7 +590,8 @@ class Shape { // - This is an array shape. // - Either `value` is >= 0, or `is_dynamic` is true and `value` is // kUnboundedSize. - void add_dimensions(int64_t value, bool is_dynamic = false); + void add_dimensions(int64_t value, bool is_dynamic = false, + xla::DynExpr* expr = nullptr); // Clears all dimensions (i.e. makes this shape a scalar). // Precondition: this is an array shape. @@ -321,6 +599,7 @@ class Shape { auto& state = array_state(); state.dimensions.clear(); state.dynamic_dimensions.clear(); + state.expressions.clear(); } // Returns a span to indicate the size of each dimension. @@ -520,7 +799,7 @@ class Shape { } if (const auto* const state = s.if_array_state()) { h = H::combine(std::move(h), s.element_type_, state->dimensions, - state->dynamic_dimensions); + state->dynamic_dimensions, state->expressions); if (kIsLayoutSensitive) { h = H::combine(std::move(h), state->layout); } @@ -541,6 +820,7 @@ class Shape { void set_outer_multiplier(int64_t m) { outer_multiplier_ = m; } private: int64_t outer_multiplier_ = -1; + friend absl::Status ValidateNonLayoutProperties(const Shape& shape); // Define one state struct for each shape category. Depending on the element @@ -568,6 +848,8 @@ class Shape { // respective dimension is dynamically sized. absl::InlinedVector dynamic_dimensions; + absl::InlinedVector expressions; + // The layout of the shape. std::optional layout; }; @@ -591,7 +873,8 @@ class Shape { // Instead, we rely on validation down the road to catch invalid shapes. // This is useful for code that should not crash, such as constructing a // Shape from an unvalidated proto. - void UnsafeAddDimension(int64_t value, bool is_dynamic); + void UnsafeAddDimension(int64_t value, bool is_dynamic, + DynExpr* exp = nullptr); // Convenience accessors for the state_ variant. Each if_*_state() accessor // returns a pointer to the corresponding state struct, or nullptr if the diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index e0a327e29e57da..de0ab94330ac67 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -267,9 +267,20 @@ static std::vector MakeDynamicDimensions( return dynamic_dimensions; } +static std::vector MakeExpressions( + absl::Span dimensions) { + std::vector expressions; + expressions.reserve(dimensions.size()); + for (int64_t d : dimensions) { + expressions.push_back(DynExpr::_(d)); + } + return expressions; +} + /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, - absl::Span dimensions) { - return MakeValidatedShape(element_type, dimensions).value(); + absl::Span dimensions, + absl::Span expressions) { + return MakeValidatedShape(element_type, dimensions, expressions).value(); } /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) { @@ -278,8 +289,10 @@ static std::vector MakeDynamicDimensions( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions) { - return MakeValidatedShape(element_type, dimensions, dynamic_dimensions) + const std::vector& dynamic_dimensions, + absl::Span expressions) { + return MakeValidatedShape(element_type, dimensions, dynamic_dimensions, + expressions) .value(); } @@ -296,20 +309,27 @@ static std::vector MakeDynamicDimensions( } /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( - PrimitiveType element_type, absl::Span dimensions) { - return MakeValidatedShape(element_type, dimensions, - MakeDynamicDimensions(dimensions)); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions) { + return MakeValidatedShape( + element_type, dimensions, MakeDynamicDimensions(dimensions), + expressions.empty() ? MakeExpressions(dimensions) : expressions); } /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions) { + const std::vector& dynamic_dimensions, + absl::Span expressions) { if (dynamic_dimensions.size() != dimensions.size()) { return InvalidArgument( "dynamic dimensions size %d did not match number of dimensions %d", dynamic_dimensions.size(), dimensions.size()); } - + if (expressions.size() != dimensions.size()) { + return InvalidArgument( + "expressions size %d did not match number of dimensions %d", + expressions.size(), dimensions.size()); + } Shape shape; int64_t dense_shape_size = primitive_util::IsArrayType(element_type) ? primitive_util::ByteWidth(element_type) @@ -328,6 +348,7 @@ static std::vector MakeDynamicDimensions( for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; const bool is_dynamic = dynamic_dimensions[i]; + DynExpr* expression = expressions[i]; if (!Shape::IsValidDimensionSize(d, is_dynamic)) { return InvalidArgument("Invalid dimension size %d, is_dynamic=%s", d, is_dynamic ? "true" : "false"); @@ -339,7 +360,7 @@ static std::vector MakeDynamicDimensions( any_overflows |= overflow; } - shape.add_dimensions(d, is_dynamic); + shape.add_dimensions(d, is_dynamic, expression); minor_to_major->push_back(ndims - 1 - i); } @@ -408,10 +429,14 @@ static std::vector MakeDynamicDimensions( } /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType element_type, absl::Span dimensions) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); - return MakeShapeWithDenseLayout(element_type, dimensions, layout); + auto shape = MakeShapeWithDenseLayout(element_type, dimensions, layout); + std::vector exprs(expressions.begin(), expressions.end()); + shape.set_expressions(exprs); + return shape; } /* static */ Shape @@ -442,6 +467,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( dim = LayoutUtil::Major(shape.layout(), dim); } new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(dim)); + new_shape.set_expression(i, shape.expressions(dim)); } new_shape.mutable_layout()->set_memory_space(shape.layout().memory_space()); return new_shape; @@ -731,7 +757,22 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { printer->Append("?"); } } else { + // Only print constant expression if it is different than the dimension + // (i.e. it is wrong!) + bool is_wrong = shape.expressions(i)->is_constant() && + shape.expressions(i)->get_val() != shape.dimensions(i); printer->Append(shape.dimensions(i)); + if (is_wrong) { + LOG(ERROR) << "THIS SHOULD NEVER HAPPEN! " << shape.ToString(); + printer->Append("print(printer); + printer->Append("!>"); + } + if (shape.expressions(i) && (shape.expressions(i)->is_dynamic())) { + printer->Append("<"); + shape.expressions(i)->print(printer); + printer->Append(">"); + } } }; print_dimension(0); @@ -880,6 +921,11 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); } +/* static */ DynExpr* ShapeUtil::GetExpression(const Shape& shape, + int64_t dimension_number) { + return shape.expressions(GetDimensionNumber(shape, dimension_number)); +} + /* static */ int64_t ShapeUtil::GetDimensionNumber(const Shape& shape, int64_t dimension_number) { if (dimension_number < 0) { @@ -1216,8 +1262,11 @@ ShapeUtil::PackedFactorFor1DInterleavedArray(const Shape& shape) { const auto permuted_dims = Permute(shape.dimensions(), permutation); const auto permuted_dynamic_dims = Permute(shape.dynamic_dimensions(), permutation); + const auto permuted_expressions = + Permute(shape.expressions(), permutation); for (int i = 0; i < permuted_dims.size(); ++i) { - new_shape.add_dimensions(permuted_dims[i], permuted_dynamic_dims[i]); + new_shape.add_dimensions(permuted_dims[i], permuted_dynamic_dims[i], + permuted_expressions[i]); } // If `shape` has a layout, by contract we choose a new layout such that the diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 4e3602671fa153..7e9ed98719ff4f 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -330,6 +330,10 @@ class ShapeUtil { // GetDimensionNumber(dimension_number). static int64_t GetDimension(const Shape& shape, int64_t dimension_number); + // Extracts the shape's expressions at dimension number + // GetDimensionNumber(dimension_number). + static DynExpr* GetExpression(const Shape& shape, int64_t dimension_number); + // Resolves a dimension number, supporting negative indexing. // // Negative indexing has similar semantics to Python. For an N-dimensional @@ -406,7 +410,8 @@ class ShapeUtil { // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, - absl::Span dimensions); + absl::Span dimensions, + absl::Span expressions = {}); // Make a scalar shape with given primitive type. static Shape MakeScalarShape(PrimitiveType element_type); @@ -419,7 +424,8 @@ class ShapeUtil { // the same size. static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions); + const std::vector& dynamic_dimensions, + absl::Span expressions); // Constructs a new buffer shape with the given element type, and sequence of // dimensions. static Shape MakeBufferShape(PrimitiveType element_type, @@ -430,10 +436,13 @@ class ShapeUtil { // size fits in std::numeric_limits::max(), and dynamic size is not // marked static. static absl::StatusOr MakeValidatedShape( - PrimitiveType element_type, absl::Span dimensions); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions = {}); + static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions); + const std::vector& dynamic_dimensions, + absl::Span expressions = {}); // Creates a Shape with element type corresponding to T and the given // dimensions @@ -476,7 +485,8 @@ class ShapeUtil { // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( - PrimitiveType element_type, absl::Span dimensions); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions = {}); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 0f3ba256573de1..e06db894249632 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -247,6 +247,8 @@ typedef struct XLA_Shape { int element_type; Int64List dimensions; BoolList dynamic_dimensions; + Int64List batch_multipliers; + Int64List batch_offsets; struct XLA_Shape* tuple_shapes; // owned int ntuple_shapes; bool has_layout; diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 85ad79ae841078..798dc12fb5734e 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -361,7 +361,7 @@ message ShapeProto { // The layout used to back this shape. LayoutProto layout = 5; - int64 batch_multiplier = 7; + repeated ExpressionProto expressions = 7; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and Shape::Hash() appropriately to account for the @@ -1206,3 +1206,34 @@ message OriginalArrayProto { message OriginalValueProto { repeated OriginalArrayProto leaves = 1; } + +message ExpressionProto { + oneof node_type { + int32 constant_value = 1; // cons + int32 variable_id = 2; // var + AddNode add_node = 3; // exp + exp + SubNode sub_node = 4; // exp - exp + MulNode mul_node = 5; // exp * exp + DivNode div_node = 6; // exp / exp + } +} + +message AddNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message SubNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message MulNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message DivNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} \ No newline at end of file From bed9d0ff05e7014130a1fe308953c596ec62b4c5 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:34:36 +0000 Subject: [PATCH 06/16] Add dynamic value capture and clustering guardrails Captures dynamic dimension values at compile time and adds clustering safeguards so incompatible dynamic-shape regions are not merged accidentally. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- tensorflow/compiler/jit/flags.cc | 5 +- tensorflow/compiler/jit/flags.h | 2 + .../compiler/jit/mark_for_compilation_pass.cc | 234 +++++++++++ tensorflow/compiler/jit/xla_batch_matcher.cc | 8 +- tensorflow/core/framework/tensor_shape.h | 9 +- .../core/grappler/costs/graph_properties.cc | 51 ++- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/shape.h | 263 +----------- third_party/xla/xla/shape_dynexpr.h | 379 ++++++++++++++++++ 9 files changed, 664 insertions(+), 288 deletions(-) create mode 100644 third_party/xla/xla/shape_dynexpr.h diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 5d8a1904ab49a2..41a2f89f49be55 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -110,10 +110,13 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { "Maximum number of operators in an XLA compilation."), Flag("tf_xla_annotate_cluster_id", &mark_for_compilation_flags->tf_xla_annotate_cluster_id, - "Allow operator names to influence clustering scheume." + "Allow operator names to influence clustering scheme." "Operators whose name starting with .cluster.{id} will likely" "to be clustered together if the ids are the same number. " ".cluster.none will not be clustered with those having numbered id"), + Flag("tf_xla_cluster_single_dynamic_dim", + &mark_for_compilation_flags->tf_xla_cluster_single_dynamic_dim, + "Only allow clustering of a single dynamic dimension."), Flag("tf_xla_cluster_parallel", &mark_for_compilation_flags->tf_xla_cluster_parallel, "Split parallel compute subgraph info different clusters"), diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 971dd8a7a38229..93fdc016860654 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -125,6 +125,8 @@ struct MarkForCompilationPassFlags { // Setting it to -1 disables marking clusters megamorphic. // Setting it to 0 uses the default behaviour of TensorFlow. int64_t tf_xla_threshold_for_megamorphic; + + bool tf_xla_cluster_single_dynamic_dim; // New flag for single dynamic dim clustering }; // Flags associated with XLA Sparse Core. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 0b33abc4612931..82b920e9b14019 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -135,6 +135,8 @@ class MarkForCompilationPassImpl { int annotate_cluster_id; bool enable_cluster_parallel; + + bool cluster_single_dynamic_dim; // New flag to control single dynamic dim clustering }; MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, @@ -261,9 +263,12 @@ class MarkForCompilationPassImpl { void set_annotated_id(int id) { annotated_id_ = id; } int chain_id() const {return chain_id_;} void set_chain_id(int id) {chain_id_ = id;} + void add_dim_var(int dim_var) { dim_vars_.insert(dim_var); } + const std::set& dim_vars() const { return dim_vars_; } private: int annotated_id_ = -1; + std::set dim_vars_; int chain_id_ = -1; int cluster_size_ = 1; int cycles_graph_node_id_; @@ -337,6 +342,7 @@ class MarkForCompilationPassImpl { } absl::Status AssignAnnotatedClusterIDs(); + absl::Status AssignDimVars(); void collectInputNodes(std::set &path_nodes); void collectMergeNodes(const std::vector& nodeSet, std::set &merger_nodes); @@ -680,6 +686,166 @@ absl::Status IgnoreResourceOpForSafetyAnalysis( } return absl::OkStatus(); } +// node mapping to multiple vectors of expressions (one for each output in +// order) +static std::map>>> + expr_map; +// Helper to convert ExpressionProto to a readable string. +std::string ExprProtoToString(const ExpressionProto& e) { + switch (e.node_type_case()) { + case ExpressionProto::kConstantValue: + return std::to_string(e.constant_value()); + case ExpressionProto::kVariableId: + return absl::StrCat("Var(", e.variable_id(), ")"); + case ExpressionProto::kAddNode: + return absl::StrCat("(", ExprProtoToString(e.add_node().lhs()), " + ", + ExprProtoToString(e.add_node().rhs()), ")"); + case ExpressionProto::kSubNode: + return absl::StrCat("(", ExprProtoToString(e.sub_node().lhs()), " - ", + ExprProtoToString(e.sub_node().rhs()), ")"); + case ExpressionProto::kMulNode: + return absl::StrCat("(", ExprProtoToString(e.mul_node().lhs()), " * ", + ExprProtoToString(e.mul_node().rhs()), ")"); + case ExpressionProto::kDivNode: + return absl::StrCat("(", ExprProtoToString(e.div_node().lhs()), " / ", + ExprProtoToString(e.div_node().rhs()), ")"); + default: + return ""; + } +} + +std::unique_ptr ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DimExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DimExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = ExprFromProto(proto.add_node().lhs()); + auto rhs = ExprFromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = ExprFromProto(proto.sub_node().lhs()); + auto rhs = ExprFromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = ExprFromProto(proto.mul_node().lhs()); + auto rhs = ExprFromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = ExprFromProto(proto.div_node().lhs()); + auto rhs = ExprFromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { + switch (e->kind()) { + case DimExpr::Kind::kConstant: { + auto* ac = static_cast(e); + return xla::DynExpr::_(ac->value()); + } + case DimExpr::Kind::kVariable: { + auto* av = static_cast(e); + return xla::DynExpr::V(av->id()); // Use 1 all the time for now + } + case DimExpr::Kind::kAdd: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kSub: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kMul: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kDiv: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs()); + } + } + return nullptr; +} + +// Runs Grappler static inference and logs any ExpressionProto found in output +// tensor shapes (from GraphProperties, not from _output_shapes attrs). +void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { + using tensorflow::ExpressionProto; + using tensorflow::GraphDef; + using tensorflow::NodeDef; + using tensorflow::TensorShapeProto; + using tensorflow::grappler::GraphProperties; + using tensorflow::grappler::GrapplerItem; + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + GrapplerItem item; + item.id = "mark_for_compilation_pass_expr_dump"; + item.graph = graph_def; + + GraphProperties props(item); + + absl::Status st = props.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/false, + /*include_input_tensor_values=*/false, + /*include_output_tensor_values=*/false); + + if (!st.ok()) { + LOG(ERROR) << "[EXPR][GP] InferStatically failed: " << st.message(); + return; + } + + int found = 0; + VLOG(1) << "[EXPR][GP] === GraphProperties output expr dump ==="; + + + for (const NodeDef& n : graph_def.node()) { + if (!props.HasOutputProperties(n.name())) continue; + const auto& outs = props.GetOutputProperties(n.name()); + std::vector>> list_exprs(outs.size()); + for (int out_idx = 0; out_idx < static_cast(outs.size()); ++out_idx) { + const auto& tp = outs[out_idx]; + const TensorShapeProto& shp = tp.shape(); + + if (shp.unknown_rank()) continue; + std::vector> exprs; + for (int d = 0; d < shp.dim_size(); ++d) { + const auto& dim = shp.dim(d); + + const ExpressionProto& expr = dim.expr(); + if (expr.node_type_case() == ExpressionProto::NODE_TYPE_NOT_SET) + continue; + + VLOG(1) << "Node " << n.name() << " has expression " + << ExprProtoToString(expr); + + auto ex = ExprFromProto(expr); + exprs.push_back(std::move(ex)); + + ++found; + } + list_exprs[out_idx] = std::move(exprs); + } + expr_map[n.name()] = std::move(list_exprs); + + } + + VLOG(1) << "[EXPR][GP] === Found " << found + << " expressions via GraphProperties ==="; +} absl::StatusOr MarkForCompilationPassImpl::Initialize() { TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_); @@ -721,6 +887,10 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { if (debug_options_.annotate_cluster_id) { TF_RETURN_IF_ERROR(AssignAnnotatedClusterIDs()); } + if (debug_options_.cluster_single_dynamic_dim) { + LogExpressionsViaGraphProperties(*graph_); + TF_RETURN_IF_ERROR(AssignDimVars()); + } if (debug_options_.enable_cluster_parallel) { TF_RETURN_IF_ERROR(AssignParallelChains()); } @@ -1626,6 +1796,56 @@ absl::Status MarkForCompilationPassImpl::AssignAnnotatedClusterIDs(void) { return absl::OkStatus(); } +absl::Status MarkForCompilationPassImpl::AssignDimVars(void) { + for (Node* node : graph_->nodes()) { + auto node_name = node->name(); + Cluster * cluster = GetClusterForNode(node); + if (!cluster) continue; + for (const tensorflow::Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) { + // Skip control edges if you are only interested in data edges + continue; + } + + const tensorflow::Node* input = edge->src(); // Source node of the edge + auto it = expr_map.find(input->name()); + if (it == expr_map.end()) { + VLOG(2) << "No expression found for node " << input->name(); + continue; + } + + auto output_index = edge->src_output(); // Output index of the source node + if (output_index >= (it->second).size()) { + LOG(INFO) << "Warning: Output index " << output_index << " is out of bounds for node " << input->name(); + continue; + } + for (auto& pDim: (it->second)[output_index]) { + DimExpr * d= pDim.get(); + xla::DynExpr * dyn = DimExprToDynExpr(d); + auto new_ids = dyn->get_all_ids(); + for (auto id : new_ids) { + cluster->add_dim_var(id); + VLOG(2) << "Add dim var " << id << " to cluster of node "<< node_name; + } + } + } + // create a for loop for each dim vars in cluster and print each dim var + if (VLOG_IS_ON(2)) { + if (cluster->dim_vars().empty()) { + VLOG(2) << "Cluster of node " << node_name << " has no dim vars."; + } + else { + std::string id_str; + for (auto id : cluster->dim_vars()) { + id_str += "Dim var " + std::to_string(id) + ", "; + } + VLOG(2) << "Cluster of node " << node_name << " has dim vars:\n" << id_str; + } + } + } + return absl::OkStatus(); +} + bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( Cluster* from, Cluster* to, absl::string_view reason) { VLOG(3) << EdgeContractionFailureMsg(from, to, reason); @@ -1852,6 +2072,18 @@ absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( from, to, "the two nodes do not have same annotated ids"); } + if (debug_options_.cluster_single_dynamic_dim) { + if (from->dim_vars().size() > 1 || to->dim_vars().size() > 1) { + return LogNotContractableAndReturnFalse( + from, to, "the two nodes have multiple dynamic dimensions"); + } + if (from->dim_vars().size() == 1 && to->dim_vars().size() == 1 && + from->dim_vars() != to->dim_vars()) { + return LogNotContractableAndReturnFalse( + from, to, "the two nodes have different dynamic dimensions"); + } + } + TF_ASSIGN_OR_RETURN(bool devices_compatible, AreDevicesCompatible(*from, *to)); if (!devices_compatible) { @@ -2248,6 +2480,8 @@ absl::Status MarkForCompilationPass::Run( debug_options.dump_graphs = flags->tf_xla_clustering_debug; debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; debug_options.enable_cluster_parallel = flags->tf_xla_cluster_parallel; + debug_options.cluster_single_dynamic_dim = + flags->tf_xla_cluster_single_dynamic_dim; // Updated option name return MarkForCompilation(options, debug_options); } diff --git a/tensorflow/compiler/jit/xla_batch_matcher.cc b/tensorflow/compiler/jit/xla_batch_matcher.cc index 088eef4139855a..3f3641568996d9 100644 --- a/tensorflow/compiler/jit/xla_batch_matcher.cc +++ b/tensorflow/compiler/jit/xla_batch_matcher.cc @@ -90,13 +90,13 @@ void XlaBatchMatcher::print_all_batches() { void XlaBatchMatcher::parse_env_config() { // If the env var not set or is empty, filled with the nearest power of two by default if (!env_str_) { - LOG(INFO) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << + VLOG(2) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << " not set, filled with the nearest power of two by default"; return; } if (std::string(env_str_).empty()) { - LOG(INFO) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << + VLOG(2) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << "is empty, filled with the nearest power of two by default"; return; } @@ -176,10 +176,10 @@ int64_t XlaBatchMatcher::get_xla_compile_batch(int64_t real_batch) { int64_t selected = find_min_larger_batch(real_batch); if (real_batch != last_batch_ || all_batches_.empty()) { last_batch_ = real_batch; - LOG(INFO) << "[XLA_BATCH_INFO] Real batch: " << real_batch + VLOG(2) << "[XLA_BATCH_INFO] Real batch: " << real_batch << " -> Selected compile batch: " << selected; } return selected; } -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index f271714fc989d6..31fb37b24e777d 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" -#include "xla/shape.h" +#include "xla/shape_dynexpr.h" namespace tensorflow { @@ -95,10 +95,13 @@ class TensorShapeRep { // Return the multiplier for a specific dynamic dimension. // -1 if the dimension is not dynamic. xla::DynExpr* get_expression(int64_t dimension) const { - if (dimension >= expressions_.size()) { + // Guard against negative indices and avoid signed/unsigned comparison + if (dimension < 0) return nullptr; + const size_t dim = static_cast(dimension); + if (dim >= expressions_.size()) { return nullptr; } - return expressions_[dimension]; + return expressions_[dim]; } protected: diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index cecf7459ab5f4a..6d6a3a8568ffd8 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1996,38 +1996,37 @@ class SymbolicShapeRefiner { for (int out = 0; out < ic->num_outputs(); ++out) { ShapeHandle s = ic->output(out); - if (!ic->RankKnown(s) && node->op() == "_Arg") { - // Treat batched function arguments as vectors with batch dim at dim0. - DimensionHandle d0 = GetUnknownOutputDim(node, out, /*dim_index=*/0); - ShapeHandle vec = ic->MakeShape({d0}); - ic->set_output(out, vec); - s = vec; - } - - if (!ic->RankKnown(s)){ - //if Rank is not realized yet, get it from attr. + if (!ic->RankKnown(s)) { auto it = node->attr().find("_output_shapes"); - if (it != node->attr().end() && it->second.list().shape_size()>0){ - const TensorShapeProto& proto = it->second.list().shape(out); + if (it == node->attr().end() || out >= it->second.list().shape_size()) { + VLOG(1) << "RANK still unknown. " << node->name(); + continue; + } - std::vector dims; - dims.reserve(proto.dim_size()); + const TensorShapeProto& proto = it->second.list().shape(out); + if (proto.unknown_rank()) { + continue; + } - for(int d=0; d=0){ - dims.push_back(ic->MakeDim(size)); - } else { - dims.push_back(GetUnknownOutputDim(node, out, d)); - } + std::vector dims; + dims.reserve(proto.dim_size()); + + for (int d = 0; d < proto.dim_size(); ++d) { + int64_t size = proto.dim(d).size(); + if (size >= 0) { + dims.push_back(ic->MakeDim(size)); + } else { + dims.push_back(GetUnknownOutputDim(node, out, d)); } - s = ic->MakeShape(dims); - ic->set_output(out,s); - }else { - VLOG(1) << "RANK still unknown." << node->name(); - continue; } + s = ic->MakeShape(dims); + ic->set_output(out, s); } + + if (!ic->RankKnown(s)) { + continue; + } + bool changed = false; std::vector dims; dims.reserve(ic->Rank(s)); diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index e94208f90f4537..69d78b710885cb 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -469,6 +469,7 @@ cc_library( "layout_util.h", "primitive_util.h", "shape.h", + "shape_dynexpr.h", "shape_partition.h", "shape_util.h", ], diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 7c535c2e1e769a..1c5197ad838a88 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -38,261 +38,10 @@ limitations under the License. #include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "xla/shape_dynexpr.h" namespace xla { -class DynExpr { - public: - virtual ~DynExpr() = default; - virtual void print(xla::Printer* printer) const = 0; - virtual void to_proto(xla::ExpressionProto* proto) const = 0; - virtual bool is_constant() const = 0; - virtual int64_t get_val() const { return -1; } - virtual DynExpr* s() = 0; // simplify - virtual DynExpr* substitute(int id, DynExpr* v) = 0; - - bool is_dynamic() { return !is_constant(); } - - static DynExpr* zero; - static DynExpr* one; - static DynExpr* _(int64_t val); - static DynExpr* V(int var_id); - static DynExpr* _s(DynExpr* expr); - static bool equal(DynExpr* expr1, DynExpr* expr2); - - friend std::ostream& operator<<(std::ostream& os, DynExpr* expr); -}; - -// constant i -class Constant : public DynExpr { - int64_t value; - - public: - explicit Constant(int64_t v) : value(v) {} - void print(xla::Printer* printer) const override { - if (value < 0) { - printer->Append("("); - } - printer->Append(value); - if (value < 0) { - printer->Append(")"); - } - } - void to_proto(xla::ExpressionProto* proto) const override { - proto->set_constant_value(value); - } - bool is_constant() const override { return true; } - int64_t get_val() const override { return value; } - DynExpr* substitute(int id, DynExpr* v) { return this; } - DynExpr* s() override; -}; - -// var id (int) -class Variable : public DynExpr { - int id; - - public: - explicit Variable(int identifier) : id(identifier) {} - void print(xla::Printer* printer) const override { - // printer->Append("(Var "); - char letter = 'A' + (id - 1); - printer->Append(std::string(1, letter)); - // printer->Append(")"); - } - void to_proto(xla::ExpressionProto* proto) const override { - proto->set_variable_id(id); - } - bool is_constant() const override { return false; } - int get_id() const { return id; } - DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} - DynExpr* s() override; -}; - -// exp = exp + exp -class Add : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; - - public: - Add(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} - void print(xla::Printer* printer) const override { - printer->Append("("); - lhs->print(printer); - printer->Append(" + "); - rhs->print(printer); - printer->Append(")"); - } - - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } - - void to_proto(xla::ExpressionProto* proto) const override { - auto* add_msg = proto->mutable_add_node(); - lhs->to_proto(add_msg->mutable_lhs()); - rhs->to_proto(add_msg->mutable_rhs()); - } - - bool is_constant() const override { - return lhs->is_constant() && rhs->is_constant(); - } - - int64_t get_val() const override { return lhs->get_val() + rhs->get_val(); } - - DynExpr* substitute(int id, DynExpr* v) { - return new Add(lhs->substitute(id, v), rhs->substitute(id, v)); - } - DynExpr* s() override; - - ~Add() { - delete lhs; - delete rhs; - } -}; - -// exp = exp - exp -class Sub : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; - - public: - Sub(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} - void print(xla::Printer* printer) const override { - printer->Append("("); - lhs->print(printer); - printer->Append(" - "); - rhs->print(printer); - printer->Append(")"); - } - - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } - - void to_proto(xla::ExpressionProto* proto) const override { - auto* sub_msg = proto->mutable_sub_node(); - lhs->to_proto(sub_msg->mutable_lhs()); - rhs->to_proto(sub_msg->mutable_rhs()); - } - - bool is_constant() const override { - return lhs->is_constant() && rhs->is_constant(); - } - - int64_t get_val() const override { return lhs->get_val() - rhs->get_val(); } - - DynExpr* substitute(int id, DynExpr* v) { - return new Sub(lhs->substitute(id, v), rhs->substitute(id, v)); - } - - DynExpr* s() override; - - ~Sub() { - delete lhs; - delete rhs; - } -}; - -// exp = exp * exp -class Mul : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; - - public: - Mul(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} - void print(xla::Printer* printer) const override { - printer->Append("("); - lhs->print(printer); - printer->Append(" * "); - rhs->print(printer); - printer->Append(")"); - } - - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } - - void to_proto(xla::ExpressionProto* proto) const override { - auto* mul_msg = proto->mutable_mul_node(); - lhs->to_proto(mul_msg->mutable_lhs()); - rhs->to_proto(mul_msg->mutable_rhs()); - } - - bool is_constant() const override { - return lhs->is_constant() && rhs->is_constant(); - } - - int64_t get_val() const override { return lhs->get_val() * rhs->get_val(); } - - DynExpr* substitute(int id, DynExpr* v) { - return new Mul(lhs->substitute(id, v), rhs->substitute(id, v)); - } - - DynExpr* s() override; - - ~Mul() { - delete lhs; - delete rhs; - } -}; - -// expr / expr -class Div : public DynExpr { - DynExpr* lhs; - DynExpr* rhs; - - public: - Div(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} - void print(xla::Printer* printer) const override { - printer->Append("("); - lhs->print(printer); - printer->Append(" / ( "); - rhs->print(printer); - printer->Append(") )"); - } - - DynExpr* get_lhs() const { return lhs; } - DynExpr* get_rhs() const { return rhs; } - - void to_proto(xla::ExpressionProto* proto) const override { - auto* div_msg = proto->mutable_div_node(); - lhs->to_proto(div_msg->mutable_lhs()); - rhs->to_proto(div_msg->mutable_rhs()); - } - - bool is_constant() const override { - return lhs->is_constant() && rhs->is_constant(); - } - - int64_t get_val() const override { return lhs->get_val() / rhs->get_val(); } - - DynExpr* substitute(int id, DynExpr* v) { - return new Div(lhs->substitute(id, v), rhs->substitute(id, v)); - } - - DynExpr* s() override; - - ~Div() { - delete lhs; - delete rhs; - } -}; - -DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); -DynExpr* operator*(int64_t k, DynExpr& rhs); -DynExpr* operator/(DynExpr& lhs, DynExpr& rhs); -DynExpr* operator/(DynExpr& lhs, int64_t d); -DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); -DynExpr* operator+(DynExpr& lhs, int64_t d); -DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); -DynExpr* operator-(DynExpr& lhs, int64_t d); -bool operator==(DynExpr& lhs, DynExpr& rhs); -bool operator==(DynExpr& lhs, int64_t d); - -inline DynExpr* DynExpr::_(int64_t val) { - if (val == 0) return DynExpr::zero; - if (val == 1) return DynExpr::one; - return new Constant(val); -} -inline DynExpr* DynExpr::V(int var_id) { return new Variable(var_id); } - // A shape describes the number of dimensions in a array, the bounds of each // dimension, and the primitive component type. For tuples, shape describes the // structure (number of elements and nesting). @@ -473,7 +222,9 @@ class Shape { bool has_dynamic_expr() const { if (auto* const state = if_array_state()) { return absl::c_any_of(state->expressions, - [](DynExpr* e) { return e->is_dynamic(); }); + [](DynExpr* e) { + return e != nullptr && e->is_dynamic(); + }); } if (auto* const state = if_tuple_state()) { return absl::c_any_of(state->tuple_shapes, [](Shape subshape) { @@ -484,7 +235,11 @@ class Shape { } DynExpr* expressions(int dimension) const { - return array_state().expressions[dimension]; + if (dimension < 0) return DynExpr::_(-999); + const auto& exprs = array_state().expressions; + const size_t dim = static_cast(dimension); + if (dim >= exprs.size()) return DynExpr::_(-999); + return exprs[dim] != nullptr ? exprs[dim] : DynExpr::_(-999); } // Returns true if the given dimension is statically-sized. diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h new file mode 100644 index 00000000000000..5c1f5645c25e0e --- /dev/null +++ b/third_party/xla/xla/shape_dynexpr.h @@ -0,0 +1,379 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SHAPE_DYNEXPR_H_ +#define XLA_SHAPE_DYNEXPR_H_ + +#include +#include +#include + +#include "xla/printer.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class DynExpr { + public: + virtual ~DynExpr() = default; + virtual void print(xla::Printer* printer) const = 0; + virtual void to_proto(xla::ExpressionProto* proto) const = 0; + virtual bool is_constant() const = 0; + virtual int64_t get_val() const { return -1; } + virtual DynExpr* s() = 0; // simplify + virtual DynExpr* substitute(int id, DynExpr* v) = 0; + virtual std::set get_all_ids() = 0; + virtual int64_t solve(int64_t x) = 0; + + bool is_dynamic() { return !is_constant(); } + + static DynExpr* zero; + static DynExpr* one; + static DynExpr* _(int64_t val); + static DynExpr* V(int var_id); + static DynExpr* _s(DynExpr* expr); + static bool equal(DynExpr* expr1, DynExpr* expr2); + + friend std::ostream& operator<<(std::ostream& os, DynExpr* expr); +}; + +// constant i +class Constant : public DynExpr { + int64_t value; + + public: + explicit Constant(int64_t v) : value(v) {} + void print(xla::Printer* printer) const override { + if (value < 0) { + printer->Append("("); + } + printer->Append(value); + if (value < 0) { + printer->Append(")"); + } + } + void to_proto(xla::ExpressionProto* proto) const override { + proto->set_constant_value(value); + } + bool is_constant() const override { return true; } + int64_t get_val() const override { return value; } + DynExpr* substitute(int id, DynExpr* v) { return this; } + std::set get_all_ids() { return {}; } + int64_t solve(int64_t x) { return -1; } + DynExpr* s() override; +}; + +// var id (int) +class Variable : public DynExpr { + int id; + + public: + explicit Variable(int identifier) : id(identifier) {} + void print(xla::Printer* printer) const override { + // printer->Append("(Var "); + char letter = 'A' + (id - 1); + printer->Append(std::string(1, letter)); + // printer->Append(")"); + } + void to_proto(xla::ExpressionProto* proto) const override { + proto->set_variable_id(id); + } + bool is_constant() const override { return false; } + int get_id() const { return id; } + DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} + std::set get_all_ids() { return {get_id()}; } + int64_t solve(int64_t x) { return x; } + DynExpr* s() override; +}; + +// exp = exp + exp +class Add : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Add(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" + "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* add_msg = proto->mutable_add_node(); + lhs->to_proto(add_msg->mutable_lhs()); + rhs->to_proto(add_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() + rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Add(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... + if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A + c) = x <=> A = x - c => solve A = y with y = x - c + return lhs->solve(x - rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1) { + // (c + A) = x <=> A = x - c => solve A = y with y = x - c + return rhs->solve(x - lhs->get_val()); + } + // No solution + return -1; + } + + DynExpr* s() override; + + ~Add() { + delete lhs; + delete rhs; + } +}; + +// exp = exp - exp +class Sub : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Sub(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" - "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* sub_msg = proto->mutable_sub_node(); + lhs->to_proto(sub_msg->mutable_lhs()); + rhs->to_proto(sub_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() - rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Sub(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... + if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A - c) = x <=> A = x + c => solve A = y with y = x + c + return lhs->solve(x + rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1) { + // (c + A) = x <=> A = x - c => solve A = y with y = x + c + return rhs->solve(x + lhs->get_val()); + } + // No solution + return -1; + } + + DynExpr* s() override; + + ~Sub() { + delete lhs; + delete rhs; + } +}; + +// exp = exp * exp +class Mul : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Mul(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" * "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* mul_msg = proto->mutable_mul_node(); + lhs->to_proto(mul_msg->mutable_lhs()); + rhs->to_proto(mul_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() * rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Mul(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... + if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A * c) = x <=> A = x / c => solve A = y with y = x / c + int64_t c = rhs->get_val(); + if (x % c != 0) return -1; + return lhs->solve(x / c); + } + if (rhs->get_all_ids().size() == 1) { + // (c * A) = x <=> A = x / c => solve A = y with y = x / c + int64_t c = lhs->get_val(); + if (x % c != 0) return -1; + return rhs->solve(x / c); + } + // No solution + return -1; + } + + DynExpr* s() override; + + ~Mul() { + delete lhs; + delete rhs; + } +}; + +// expr / expr +class Div : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Div(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" / ( "); + rhs->print(printer); + printer->Append(") )"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* div_msg = proto->mutable_div_node(); + lhs->to_proto(div_msg->mutable_lhs()); + rhs->to_proto(div_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() / rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Div(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + DynExpr* s() override; + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... + if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A / c) = x <=> A = x * c => solve A = y with y = x * c + return lhs->solve(x * rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1) { + // (c / A) = x <=> A = c / x => solve A = y with y = c / x + int64_t c = lhs->get_val(); + if (c % x != 0) return -1; + return rhs->solve(c / x); + } + // No solution + return -1; + } + + ~Div() { + delete lhs; + delete rhs; + } +}; + +DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator*(int64_t k, DynExpr& rhs); +DynExpr* operator/(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator/(DynExpr& lhs, int64_t d); +DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator+(DynExpr& lhs, int64_t d); +DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator-(DynExpr& lhs, int64_t d); +bool operator==(DynExpr& lhs, DynExpr& rhs); +bool operator==(DynExpr& lhs, int64_t d); + +inline DynExpr* DynExpr::_(int64_t val) { + if (val == 0) return DynExpr::zero; + if (val == 1) return DynExpr::one; + return new Constant(val); +} +inline DynExpr* DynExpr::V(int var_id) { return new Variable(var_id); } + +} // namespace xla + +#endif // XLA_SHAPE_DYNEXPR_H_ From 12879ac50f6e479628f970491db730b31918f9a2 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:35:56 +0000 Subject: [PATCH 07/16] Preserve output shape metadata across JIT boundaries Keeps inferred shape metadata alive when values cross JIT and encapsulation boundaries so downstream compilation can still recover symbolic shape information. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 11 +++++++---- tensorflow/core/graph/subgraph.cc | 9 +++++++-- tensorflow/core/kernels/function_ops.cc | 9 +++++---- tensorflow/core/kernels/function_ops.h | 4 ++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 0662976897378b..65a2252e3d12a4 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -231,8 +231,8 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { if (expr.node_type_case() == ExpressionProto::NODE_TYPE_NOT_SET) continue; - VLOG(1) << "Node " << n.name() << " has expression " - << ExprProtoToString(expr); + VLOG(1) << "Node " << n.name() << " is inferred to have expression " + << ExprProtoToString(expr) << " on dimension #" << d; auto ex = ExprFromProto(expr); exprs.push_back(std::move(ex)); @@ -620,13 +620,16 @@ absl::Status Encapsulator::Subgraph::RecordArg( << src_slot; builder.Attr("_is_batch", true); } + VLOG(1) << "Adding following output shapes for node " << src_node->name() + << " : " << tsp->DebugString(); + builder.Attr("_output_shapes", {*tsp}); } else { // if cluster argument is the real argument. - auto build_attr = attrs.FindByString("_is_batch"); + auto build_attr = attrs.FindByString("_dynamic_dim"); if (build_attr) { VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":" << src_slot; - builder.Attr("_is_batch", true); + builder.Attr("_dynamic_dim", *build_attr); } } absl::Status s = builder.Finalize(&arg_def); diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 0d66e1f868c2a5..d1616e2396515e 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -84,9 +84,14 @@ absl::Status FeedInputs( const AttrValue* shape_attr = node_attrs.FindByString("_output_shapes"); if (shape_attr && shape_attr->has_list()) { const TensorShapeProto& shape = shape_attr->list().shape(0); - if (shape.dim_size() >= 1 && shape.dim(0).size() == -1) { - feed_node->AddAttr("_is_batch", true); + for (int i = 0; i < shape.dim_size(); ++i) { + if (shape.dim(i).size() == -1) { + feed_node->AddAttr("_dynamic_dim", i); + break; + } } + // Keep _output_shapes for further runs of shape inference + feed_node->AddAttr("_output_shapes", *shape_attr); } // Update name_index (*name_index)[feed_node->name()] = feed_node; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index eec275448c50ce..7f131e3c23e5f4 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -44,9 +44,9 @@ ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); - Status s = ctx->GetAttr("_is_batch", &is_batch_); + Status s = ctx->GetAttr("_dynamic_dim", &dynamic_dim_); if (IsNotFound(s)) { - is_batch_ = false; + dynamic_dim_ = -1; } else { OP_REQUIRES_OK(ctx, s); } @@ -78,7 +78,7 @@ void ArgOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, validate_type(*val)); ctx->set_output(0, *val); } - if (is_batch_) { + if (dynamic_dim_ >= 0) { BatchSizeResource* bsr = nullptr; ScopedStepContainer* step_container = ctx->step_container(); @@ -89,7 +89,8 @@ void ArgOp::Compute(OpKernelContext* ctx) { return OkStatus(); })); - const int64_t batch_size = val->dim_size(0); + const int64_t batch_size = val->dim_size(dynamic_dim_); + VLOG(1) << "Found batch_size in dimension #" << dynamic_dim_; if (bsr->GetBatchSize() == 0) { bsr->SetBatchSize(batch_size); VLOG(1) << "Set batch_size from 0 to " << batch_size diff --git a/tensorflow/core/kernels/function_ops.h b/tensorflow/core/kernels/function_ops.h index c9e8fdcd894e25..fed05ead2007ad 100644 --- a/tensorflow/core/kernels/function_ops.h +++ b/tensorflow/core/kernels/function_ops.h @@ -38,7 +38,7 @@ class ArgOp : public OpKernel { private: int index_; DataType dtype_; - bool is_batch_; + int dynamic_dim_; ArgOp(const ArgOp&) = delete; void operator=(const ArgOp&) = delete; @@ -55,7 +55,7 @@ class RetvalOp : public OpKernel { private: int index_; DataType dtype_; - bool is_batch_; + int dynamic_dim_; RetvalOp(const RetvalOp&) = delete; void operator=(const RetvalOp&) = delete; From 7ce7ee6529e1c5fafc518bb9f0bddbc99edba0b9 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:37:03 +0000 Subject: [PATCH 08/16] Harden clustering and inference for unranked shapes Tightens clustering and shape-inference behavior around partially known or unranked shapes to avoid losing dynamic-shape information in unsupported cases. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../compiler/jit/mark_for_compilation_pass.cc | 7 +- .../core/grappler/costs/graph_properties.cc | 95 ++++++++++++------- 2 files changed, 68 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 82b920e9b14019..ab11a1dabcfd59 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -820,8 +820,13 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { const auto& tp = outs[out_idx]; const TensorShapeProto& shp = tp.shape(); - if (shp.unknown_rank()) continue; std::vector> exprs; + if (shp.unknown_rank()) { + // Add two dummy variables to represent the unknown rank + exprs.push_back(std::make_unique(-888)); + exprs.push_back(std::make_unique(-889)); + } + for (int d = 0; d < shp.dim_size(); ++d) { const auto& dim = shp.dim(d); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6d6a3a8568ffd8..a398158f4b9384 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1997,30 +1997,43 @@ class SymbolicShapeRefiner { ShapeHandle s = ic->output(out); if (!ic->RankKnown(s)) { + bool recovered_rank = false; auto it = node->attr().find("_output_shapes"); - if (it == node->attr().end() || out >= it->second.list().shape_size()) { - VLOG(1) << "RANK still unknown. " << node->name(); - continue; + if(it != node->attr().end() && out < it->second.list().shape_size()){ + it = node->attr().find("shape"); } - - const TensorShapeProto& proto = it->second.list().shape(out); - if (proto.unknown_rank()) { - continue; + if (it != node->attr().end() && out < it->second.list().shape_size()) { + const TensorShapeProto& proto = it->second.list().shape(out); + if (!proto.unknown_rank()) { + std::vector dims; + dims.reserve(proto.dim_size()); + + for (int d = 0; d < proto.dim_size(); ++d) { + int64_t size = proto.dim(d).size(); + if (size >= 0) { + dims.push_back(ic->MakeDim(size)); + } else { + dims.push_back(GetUnknownOutputDim(node, out, d)); + } + } + s = ic->MakeShape(dims); + ic->set_output(out, s); + recovered_rank = true; + } } - std::vector dims; - dims.reserve(proto.dim_size()); + if (!recovered_rank && node->op() == "_Arg") { + DimensionHandle d0 = GetUnknownOutputDim(node, out, /*dim_id=*/0); + ShapeHandle vec = ic->MakeShape({d0}); + ic->set_output(out, vec); + s = vec; + recovered_rank = true; + } - for (int d = 0; d < proto.dim_size(); ++d) { - int64_t size = proto.dim(d).size(); - if (size >= 0) { - dims.push_back(ic->MakeDim(size)); - } else { - dims.push_back(GetUnknownOutputDim(node, out, d)); - } + if (!recovered_rank) { + VLOG(1) << "RANK still unknown. " << node->name(); + continue; } - s = ic->MakeShape(dims); - ic->set_output(out, s); } if (!ic->RankKnown(s)) { @@ -2247,20 +2260,16 @@ class SymbolicShapeManager { int64_t d = dims_.GetMergedValue(dim); auto* out_dim = properties->mutable_shape()->add_dim(); out_dim->set_size(d < 0 ? -1 : d); - // Serialize expression for unknown dims - if (out_dim->size() < 0) { - void* root = dims_.RootId(dim); - DynExpr* expr = nullptr; - if (auto it = dim_root_expr_.find(root); it != dim_root_expr_.end()) { - expr = it->second; - } else { - // Fallback to checking the dim handle directly - expr = GetExprFromDimHandle(dim); - } - if (expr != nullptr) { - expr->ToProto(out_dim->mutable_expr()); - // TODO: Apply simplification? - } + void* root = dims_.RootId(dim); + DimExpr* expr = nullptr; + if (auto it = dim_root_expr_.find(root); it != dim_root_expr_.end()) { + expr = it->second; + } else { + expr = ExprForDim(dim); + } + if (expr != nullptr) { + expr->ToProto(out_dim->mutable_expr()); + // TODO: Apply simplification? } } } @@ -2349,6 +2358,25 @@ class SymbolicShapeManager { return d->expr_; } + DimExpr* ExprForDim(const DimensionHandle& d) { + if (!d.IsSet()) return nullptr; + if (DimExpr* expr = GetExprFromDimHandle(d)) { + return expr; + } + if (!InferenceContext::ValueKnown(d)) { + return nullptr; + } + const int64_t value = InferenceContext::Value(d); + auto it = const_exprs_.find(value); + if (it != const_exprs_.end()) { + return it->second.get(); + } + auto expr = DimExpr::Cons(value); + DimExpr* expr_ptr = expr.get(); + const_exprs_.emplace(value, std::move(expr)); + return expr_ptr; + } + absl::Status MergeDimsWithExpr(DimensionHandle d1, DimensionHandle d2) { if (!d1.IsSet() || !d2.IsSet()) return absl::OkStatus(); @@ -2359,7 +2387,7 @@ class SymbolicShapeManager { auto get_best = [&](void* r, DimensionHandle d) -> DynExpr* { auto it = dim_root_expr_.find(r); if (it != dim_root_expr_.end()) return it->second; - return GetExprFromDimHandle(d); // may be null + return ExprForDim(d); // may be null }; DynExpr* e1 = get_best(r1, d1); @@ -2400,6 +2428,7 @@ class SymbolicShapeManager { return absl::OkStatus(); } DisjointSet shapes_; + absl::flat_hash_map> const_exprs_; // Map from union-find root pointer to the best expression for that set. absl::flat_hash_map dim_root_expr_; DisjointSet dims_; From dbb162d960c0fc506f82444c04cd380444669f70 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:37:49 +0000 Subject: [PATCH 09/16] Refine tf2xla shape propagation and constant padding Improves tf2xla shape propagation and normalizes padded constants so shape-derived values can continue through lowering in a consistent way. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../jit/encapsulate_subgraphs_pass.cc | 169 +++++------------- .../compiler/jit/mark_for_compilation_pass.cc | 12 +- 2 files changed, 50 insertions(+), 131 deletions(-) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 65a2252e3d12a4..2ac1b118d3f2ac 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -55,8 +56,6 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/grappler/costs/graph_properties.h" -#include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { static const absl::flat_hash_set kFailingOps = { @@ -145,109 +144,6 @@ std::string ExprProtoToString(const ExpressionProto& e) { } } -std::map>> test_map; - -std::unique_ptr ExprFromProto(const ExpressionProto& proto) { - switch (proto.node_type_case()) { - case ExpressionProto::kConstantValue: - return DynExpr::Cons(proto.constant_value()); - case ExpressionProto::kVariableId: - return DynExpr::Var(proto.variable_id()); - case ExpressionProto::kAddNode: { - auto lhs = ExprFromProto(proto.add_node().lhs()); - auto rhs = ExprFromProto(proto.add_node().rhs()); - // Note: These are owning pointers, but ExprAdd takes raw pointers. - // The caller must manage lifetime appropriately. - return std::make_unique(lhs.release(), rhs.release()); - } - case ExpressionProto::kSubNode: { - auto lhs = ExprFromProto(proto.sub_node().lhs()); - auto rhs = ExprFromProto(proto.sub_node().rhs()); - return std::make_unique(lhs.release(), rhs.release()); - } - case ExpressionProto::kMulNode: { - auto lhs = ExprFromProto(proto.mul_node().lhs()); - auto rhs = ExprFromProto(proto.mul_node().rhs()); - return std::make_unique(lhs.release(), rhs.release()); - } - case ExpressionProto::kDivNode: { - auto lhs = ExprFromProto(proto.div_node().lhs()); - auto rhs = ExprFromProto(proto.div_node().rhs()); - return std::make_unique(lhs.release(), rhs.release()); - } - case ExpressionProto::NODE_TYPE_NOT_SET: - default: - return nullptr; - } -} - -// Runs Grappler static inference and logs any ExpressionProto found in output -// tensor shapes (from GraphProperties, not from _output_shapes attrs). -void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { - using tensorflow::ExpressionProto; - using tensorflow::GraphDef; - using tensorflow::NodeDef; - using tensorflow::TensorShapeProto; - using tensorflow::grappler::GraphProperties; - using tensorflow::grappler::GrapplerItem; - - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - GrapplerItem item; - item.id = "mark_for_compilation_pass_expr_dump"; - item.graph = graph_def; - - GraphProperties props(item); - - absl::Status st = props.InferStatically( - /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/false, - /*include_input_tensor_values=*/false, - /*include_output_tensor_values=*/false); - - if (!st.ok()) { - LOG(ERROR) << "[EXPR][GP] InferStatically failed: " << st.message(); - return; - } - - int found = 0; - VLOG(1) << "[EXPR][GP] === GraphProperties output expr dump ==="; - - - for (const NodeDef& n : graph_def.node()) { - if (!props.HasOutputProperties(n.name())) continue; - const auto& outs = props.GetOutputProperties(n.name()); - for (int out_idx = 0; out_idx < static_cast(outs.size()); ++out_idx) { - const auto& tp = outs[out_idx]; - const TensorShapeProto& shp = tp.shape(); - - if (shp.unknown_rank()) continue; - std::vector> exprs; - for (int d = 0; d < shp.dim_size(); ++d) { - const auto& dim = shp.dim(d); - - const ExpressionProto& expr = dim.expr(); - if (expr.node_type_case() == ExpressionProto::NODE_TYPE_NOT_SET) - continue; - - VLOG(1) << "Node " << n.name() << " is inferred to have expression " - << ExprProtoToString(expr) << " on dimension #" << d; - - auto ex = ExprFromProto(expr); - exprs.push_back(std::move(ex)); - - ++found; - } - test_map[n.name()] = std::move(exprs); - } - } - - VLOG(1) << "[EXPR][GP] === Found " << found - << " expressions via GraphProperties ==="; -} - - struct OutputInputTensorPairHasher { uint64 operator()(std::pair const& s) const { return Hash64Combine(OutputTensor::Hash()(s.first), @@ -503,6 +399,19 @@ class Encapsulator { namespace { +bool BuildOutputShapeProto(const Node& node, int output_slot, + TensorShapeProto* proto) { + AttrSlice attrs = node.attrs(); + auto shape_attr = + attrs.FindByString(kXlaInferredOutputTensorShapesAttrName); + if (shape_attr == nullptr || !shape_attr->has_list() || + shape_attr->list().shape_size() <= output_slot) { + return false; + } + *proto = shape_attr->list().shape(output_slot); + return true; +} + // Return in 'sorted' a topological sort of clusters according to the // dependencies encoded in ancestors. clusters is the list of all clusters // including clusters that are not present in the ancestors map. has_successors @@ -585,6 +494,31 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } +void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) { + auto e = expr->s(); + if (xla::Constant* c = dynamic_cast(e)) { + proto->set_constant_value(c->get_val()); + } else if (xla::Variable* v = dynamic_cast(e)) { + proto->set_variable_id(v->get_id()); + } else if (xla::Add* a = dynamic_cast(e)) { + auto* add_msg = proto->mutable_add_node(); + ExprToProto(a->get_lhs(), add_msg->mutable_lhs()); + ExprToProto(a->get_rhs(), add_msg->mutable_rhs()); + } else if (xla::Mul* m = dynamic_cast(e)) { + auto* mul_msg = proto->mutable_mul_node(); + ExprToProto(m->get_lhs(), mul_msg->mutable_lhs()); + ExprToProto(m->get_rhs(), mul_msg->mutable_rhs()); + } else if (xla::Sub* s = dynamic_cast(e)) { + auto* sub_msg = proto->mutable_sub_node(); + ExprToProto(s->get_lhs(), sub_msg->mutable_lhs()); + ExprToProto(s->get_rhs(), sub_msg->mutable_rhs()); + } else if (xla::Div* d = dynamic_cast(e)) { + auto* div_msg = proto->mutable_div_node(); + ExprToProto(d->get_lhs(), div_msg->mutable_lhs()); + ExprToProto(d->get_rhs(), div_msg->mutable_rhs()); + } +} + absl::Status Encapsulator::Subgraph::RecordArg( const Edge* edge, const absl::flat_hash_map& node_images, @@ -605,24 +539,12 @@ absl::Status Encapsulator::Subgraph::RecordArg( builder.Attr("T", dtype); builder.Attr("index", arg_index); AttrSlice attrs = src_node->attrs(); - auto shape_attr = attrs.FindByString("_output_shapes"); - if (shape_attr && shape_attr->has_list()) { - const TensorShapeProto& shape = shape_attr->list().shape(src_slot); - std::vector> expressions = - std::move(test_map[src_node->name()]); - for (const auto& e : expressions) { - if (!e->IsConstant()) { - builder.Attr("_is_batch", true); - } - } - if (shape.dim_size() >= 1 && shape.dim(0).size() == -1) { - VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":" - << src_slot; - builder.Attr("_is_batch", true); - } + TensorShapeProto output_shape_proto; + if (BuildOutputShapeProto(*src_node, src_slot, &output_shape_proto)) { VLOG(1) << "Adding following output shapes for node " << src_node->name() - << " : " << tsp->DebugString(); - builder.Attr("_output_shapes", {*tsp}); + << " : " << output_shape_proto.DebugString(); + builder.Attr("_output_shapes", {output_shape_proto}); + builder.Attr(kXlaInferredOutputShapesAttrName, {output_shape_proto}); } else { // if cluster argument is the real argument. auto build_attr = attrs.FindByString("_dynamic_dim"); @@ -1321,9 +1243,6 @@ absl::Status EncapsulateSubgraphsPass::Run( options.flib_def); } - LogExpressionsViaGraphProperties(**options.graph); - - // TODO(b/195757077): Remove this once there is a better way to disable // GraphOptimizationPasses that are not needed due to MLIR bridge. for (Node* n : (*options.graph)->nodes()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ab11a1dabcfd59..2d6ee4bb3f7c60 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -821,12 +821,6 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { const TensorShapeProto& shp = tp.shape(); std::vector> exprs; - if (shp.unknown_rank()) { - // Add two dummy variables to represent the unknown rank - exprs.push_back(std::make_unique(-888)); - exprs.push_back(std::make_unique(-889)); - } - for (int d = 0; d < shp.dim_size(); ++d) { const auto& dim = shp.dim(d); @@ -842,6 +836,12 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { ++found; } + if (shp.dim_size() == 0 && shp.unknown_rank()) { + // Add two dummy variables to represent the unknown rank + exprs.push_back(std::make_unique(-888)); + exprs.push_back(std::make_unique(-889)); + } + list_exprs[out_idx] = std::move(exprs); } expr_map[n.name()] = std::move(list_exprs); From 13fba41a39df480a38824b291e45c61d77706b41 Mon Sep 17 00:00:00 2001 From: utku-work Date: Thu, 19 Mar 2026 17:12:15 +0000 Subject: [PATCH 10/16] Reshape ratio issue. Fixes the reshape ratio handling so shape-expression information stays consistent when reshape dimensions are derived from symbolic values. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../compiler/tf2xla/kernels/reshape_op.cc | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index cadbafe2d27fbd..c5475d1e66485a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -71,6 +71,7 @@ class ReshapeOp : public XlaOpKernel { unknown_index = d; shape.AddDim(1); shape.AddExpression(xla::DynExpr::one); + ratio = 1; } else if (size == 0) { // We don't include zero-sized dimension in product, so that we can // still calculate number of elements for non-zero-sized dimensions and @@ -102,15 +103,17 @@ class ReshapeOp : public XlaOpKernel { size_expr = new_expr->s(); } else { - if (ratio == 1){ // Nothing has been previously split. - size_expr = xla::DynExpr::_(size); - } else if (ratio == size) { // The factor of the previous split is - // the new dimension. - size_expr = xla::DynExpr::_(size); - ratio = 1; // reset ratio - } else { - // Should not happen. - size_expr = xla::DynExpr::_(-50); + size_expr = xla::DynExpr::_(size); + if (ratio != 1) { + // A split dynamic dimension can be materialized by multiple later + // known dimensions. Any unresolved remainder is kept in `ratio` + // and may be consumed by a subsequent `-1` dimension (if present); + // otherwise, it remains unapplied. + if (ratio % size == 0) { + ratio /= size; + } else if (size % ratio == 0) { + ratio = 1; + } } } shape.AddExpression(size_expr); From ccb7e0c314fa31bfdba8fc304f13bc70fed4cee1 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Fri, 20 Mar 2026 11:50:32 +0000 Subject: [PATCH 11/16] Use Eigen only for dynamic CPU dots Limits the Eigen fallback path to dynamic CPU dot cases so static cases continue to use the normal lowering while dynamic dimensions remain supported. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../xla/xla/service/cpu/dot_op_emitter.cc | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index 5ec762ae5cd56c..e201b87a4a2a6a 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -160,6 +160,27 @@ bool CanEmitTiledLlvmIrGemm( return true; } +bool HasDynamicMatmulDims(const DotInfo& dot_info) { + const Shape& lhs_shape = dot_info.lhs_shape; + const Shape& rhs_shape = dot_info.rhs_shape; + const DotDimensionNumbers& dim_nums = dot_info.dim_nums; + + DynExpr* m_expr = lhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : lhs_shape.expressions( + 1LL - dim_nums.lhs_contracting_dimensions(0)); + DynExpr* k_expr = + lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)); + DynExpr* n_expr = rhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : rhs_shape.expressions( + 1LL - dim_nums.rhs_contracting_dimensions(0)); + + return (m_expr != nullptr && m_expr->is_dynamic()) || + (k_expr != nullptr && k_expr->is_dynamic()) || + (n_expr != nullptr && n_expr->is_dynamic()); +} + // Returns dot implementation strategy for non-batch dot operations. DotImplementationStrategy GetNonBatchDotImplementationStrategy( const HloModuleConfig& config, const DotInfo& dot_info, @@ -167,15 +188,16 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( bool allow_runtime_calls) { PrimitiveType element_type = dot_info.result_shape.element_type(); - // Force Eigen all the time. - return DotImplementationStrategy::kEigen; - // Batched dot either handled by a runtime call or expanded into a sequence // of non-batch dot operations. DCHECK(dot_info.dim_nums.lhs_batch_dimensions_size() == 0 && dot_info.dim_nums.rhs_batch_dimensions_size() == 0) << "Dot operations must be non-batch"; + if (HasDynamicMatmulDims(dot_info)) { + return DotImplementationStrategy::kEigen; + } + // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. @@ -955,18 +977,28 @@ absl::Status DotOpEmitter::EmitCallToBatchRuntime() { if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); + std::swap(mat_mult_dims.m_expr, mat_mult_dims.n_expr); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } + llvm::Value* m_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); + llvm::Value* n_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); + llvm::Value* k_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); + DynExpr* batch_size_expr = lhs_shape.expressions(0); + if (batch_size_expr == nullptr) { + batch_size_expr = DynExpr::_(lhs_shape.dimensions(0)); + } + llvm::Value* batch_size_val = + xla::llvm_ir::EmitExpression(b_, batch_size_expr); + VLOG(1) << "Batch dot emitted with runtime:" << fn_name; b_->CreateCall( matmul_func, {executable_run_options_value_, target_array_.GetBasePointer(), - lhs->GetBasePointer(), rhs->GetBasePointer(), - b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n), - b_->getInt64(mat_mult_dims.k), b_->getInt64(lhs_shape.dimensions(0)), + lhs->GetBasePointer(), rhs->GetBasePointer(), m_val, n_val, k_val, + batch_size_val, b_->getInt32(static_cast(transpose_lhs)), b_->getInt32(static_cast(transpose_rhs))}); return absl::OkStatus(); From 480d09c4263db68da49c94dafa87ece778351eb8 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Fri, 20 Mar 2026 13:19:59 +0000 Subject: [PATCH 12/16] Share expression inference across passes Shares common expression-inference helpers across passes so symbolic dimension analysis stays aligned in different parts of the pipeline. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- tensorflow/compiler/jit/encapsulate_util.cc | 4 +++ tensorflow/compiler/jit/encapsulate_util.h | 9 ++++++ .../compiler/jit/mark_for_compilation_pass.cc | 30 +++++++++++++++++-- 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index fa94a341bbabc6..e0849dcf192c81 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -303,6 +303,10 @@ absl::Status PostprocessControlEdgesBetweenOutsideCompilations( } // namespace const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; +const char kXlaInferredOutputTensorShapesAttrName[] = + "_xla_inferred_output_tensor_shapes"; +const char kXlaInferredOutputShapesAttrName[] = + "_xla_inferred_output_shapes"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 7c99763c770728..25ca891529cadc 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -28,6 +28,15 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; +// Attribute marking Grappler-inferred output TensorShapeProtos. Attribute +// value is a list of TensorShapeProto objects and may include ExpressionProto +// annotations when available. +extern const char kXlaInferredOutputTensorShapesAttrName[]; + +// Attribute carrying inferred output TensorShapeProtos on encapsulated _Arg +// nodes for XlaCompileOp argument reconstruction. +extern const char kXlaInferredOutputShapesAttrName[]; + // Infers output shapes for all nodes in graph `g`. The output shapes will be // stored in node attribute `kXlaInferredShapesAttrName`. // diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 2d6ee4bb3f7c60..f4d64c7db2481f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -780,7 +781,7 @@ static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { // Runs Grappler static inference and logs any ExpressionProto found in output // tensor shapes (from GraphProperties, not from _output_shapes attrs). -void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { +void LogExpressionsViaGraphProperties(tensorflow::Graph& graph) { using tensorflow::ExpressionProto; using tensorflow::GraphDef; using tensorflow::NodeDef; @@ -790,6 +791,7 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { GraphDef graph_def; graph.ToGraphDef(&graph_def); + auto node_name_index = graph.BuildNodeNameIndex(); GrapplerItem item; item.id = "mark_for_compilation_pass_expr_dump"; @@ -811,14 +813,31 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { int found = 0; VLOG(1) << "[EXPR][GP] === GraphProperties output expr dump ==="; + auto convert_graph_properties_shape = [](const TensorShapeProto& gp_shape) { + TensorShapeProto out; + out.set_unknown_rank(gp_shape.unknown_rank()); + for (const auto& dim : gp_shape.dim()) { + out.add_dim()->set_size(dim.size()); + ExpressionProto* expr = out.add_expressions(); + if (dim.expr().node_type_case() != ExpressionProto::NODE_TYPE_NOT_SET) { + *expr = dim.expr(); + } else { + expr->set_constant_value(dim.size()); + } + } + return out; + }; for (const NodeDef& n : graph_def.node()) { if (!props.HasOutputProperties(n.name())) continue; const auto& outs = props.GetOutputProperties(n.name()); + std::vector inferred_output_shapes; + inferred_output_shapes.reserve(outs.size()); std::vector>> list_exprs(outs.size()); for (int out_idx = 0; out_idx < static_cast(outs.size()); ++out_idx) { const auto& tp = outs[out_idx]; const TensorShapeProto& shp = tp.shape(); + inferred_output_shapes.push_back(convert_graph_properties_shape(shp)); std::vector> exprs; for (int d = 0; d < shp.dim_size(); ++d) { @@ -845,6 +864,11 @@ void LogExpressionsViaGraphProperties(const tensorflow::Graph& graph) { list_exprs[out_idx] = std::move(exprs); } expr_map[n.name()] = std::move(list_exprs); + auto node_it = node_name_index.find(n.name()); + if (node_it != node_name_index.end()) { + node_it->second->AddAttr(kXlaInferredOutputTensorShapesAttrName, + inferred_output_shapes); + } } @@ -892,8 +916,10 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { if (debug_options_.annotate_cluster_id) { TF_RETURN_IF_ERROR(AssignAnnotatedClusterIDs()); } - if (debug_options_.cluster_single_dynamic_dim) { + if (debug_options_.enable_dynamic_sizes) { LogExpressionsViaGraphProperties(*graph_); + } + if (debug_options_.cluster_single_dynamic_dim) { TF_RETURN_IF_ERROR(AssignDimVars()); } if (debug_options_.enable_cluster_parallel) { From 509ca5f4335360d7c46ab58991fda68782a4fef9 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Fri, 20 Mar 2026 13:43:20 +0000 Subject: [PATCH 13/16] Mirror the behaviour of dimension changes in expressions for more TF operators Extends expression updates across additional TensorFlow operators so symbolic dimensions continue to reflect the same shape transformations as concrete dimensions. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../compiler/tf2xla/kernels/reshape_op.cc | 25 ++++++++++++------- .../tf2xla/kernels/reverse_sequence_op.cc | 10 ++++++-- .../compiler/tf2xla/kernels/shape_op.cc | 5 +++- .../compiler/tf2xla/kernels/slice_op.cc | 14 +++++++++-- .../compiler/tf2xla/kernels/split_op.cc | 17 +++++++++---- tensorflow/core/util/strided_slice_op.cc | 12 ++++++--- 6 files changed, 60 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index c5475d1e66485a..3ac025d0408f41 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -164,9 +164,12 @@ class ReshapeOp : public XlaOpKernel { input, xla::Zero(ctx->builder(), input_xla_shape->element_type()), 0, 0, padded_input_num - input_num_elements); input_shape.set_dim(0, padded_input_num); - input_shape.set_expression( - 0, xla::DynExpr::_( - padded_input_num)); // Issue here as it depends on ceil + // This expression only approximates the padded size: the true value + // uses ceil(input_num_elements / product) * product, which we do not + // model symbolically here. + xla::DynExpr* padded_input_num_expr = + (*(*input_num_elements_expr / *product_expr) * *product_expr)->s(); + input_shape.set_expression(0, padded_input_num_expr); } } shape.set_dim(unknown_index, missing); @@ -190,14 +193,14 @@ class ReshapeOp : public XlaOpKernel { std::vector output_dim_sizes; std::vector dims_are_dynamic; - std::vector expressions; + std::vector output_dim_exprs; const auto& dims = shape.dims(); dims_are_dynamic.reserve(dims); output_dim_sizes.reserve(dims); for (int64_t i = 0; i < dims; ++i) { output_dim_sizes.push_back( xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); - expressions.push_back(xla::DynExpr::_(-10)); + output_dim_exprs.push_back(xla::DynExpr::_(-111)); } OP_REQUIRES_OK( ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); @@ -205,7 +208,7 @@ class ReshapeOp : public XlaOpKernel { // No unknown index. ctx->SetOutput( 0, xla::DynamicReshape(input, output_dim_sizes, shape.dim_sizes(), - dims_are_dynamic, expressions)); + dims_are_dynamic, output_dim_exprs)); return; } auto common_factors = @@ -216,26 +219,30 @@ class ReshapeOp : public XlaOpKernel { auto start = common_factors[i]; auto end = common_factors[i + 1]; bool input_is_dynamic = false; - xla::DynExpr* expression = xla::DynExpr::_(-20); // product of all input dims in this group. E.g., in // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group // containing -1 will be 6. xla::XlaOp product = xla::One(ctx->builder(), xla::S32); + xla::DynExpr* product_expr = xla::DynExpr::one; for (int64_t dim = start.first; dim < end.first; ++dim) { if (input_xla_shape->is_dynamic_dimension(dim)) { input_is_dynamic = true; } product = xla::Mul(product, xla::GetDimensionSize(input, dim)); + product_expr = (*product_expr * *input_shape.get_expression(dim))->s(); } bool unknown_dim_in_group = false; // The real size for the -1 dimension in a reshape. E.g., in // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2. xla::XlaOp unknown_dim_size = product; + xla::DynExpr* unknown_dim_expr = product_expr; for (int64_t dim = start.second; dim < end.second; ++dim) { if (dim == unknown_index) { unknown_dim_in_group = true; } else { unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]); + unknown_dim_expr = + (*unknown_dim_expr / *output_dim_exprs[dim])->s(); } } @@ -243,13 +250,13 @@ class ReshapeOp : public XlaOpKernel { // If input dim is dynamic, output dim at the -1 position must be // dynamic. Similarly, if input dim is static, output dim has to be // static at the -1 dimension. - expressions[unknown_index] = expression; + output_dim_exprs[unknown_index] = unknown_dim_expr; dims_are_dynamic[unknown_index] = input_is_dynamic; output_dim_sizes[unknown_index] = unknown_dim_size; ctx->SetOutput( 0, xla::DynamicReshape(input, output_dim_sizes, shape.dim_sizes(), - dims_are_dynamic, expressions)); + dims_are_dynamic, output_dim_exprs)); VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to " << xla::VectorString(shape.dim_sizes()) << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5cecbf37706283..0515dd3b785cfb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -86,11 +86,17 @@ class ReverseSequenceOp : public XlaOpKernel { xla::XlaOp back = xla::Sub(seq_lens, xla::ScalarLike(seq_lens, 1)); xla::XlaOp batch_idx = xla::Iota( builder, - xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), + xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}, + {input_shape.get_expression(batch_dim_), + input_shape.get_expression(seq_dim_), + xla::DynExpr::one}), /*iota_dimension=*/0); xla::XlaOp forward_idx = xla::Iota( builder, - xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), + xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}, + {input_shape.get_expression(batch_dim_), + input_shape.get_expression(seq_dim_), + xla::DynExpr::one}), /*iota_dimension=*/1); xla::XlaOp reverse_idx = xla::Sub(back, forward_idx, {0}); reverse_idx = xla::Select(xla::Lt(reverse_idx, xla::ZerosLike(reverse_idx)), diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index d376dfb9240876..5d37b2f4283cf8 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -340,6 +340,7 @@ class SqueezeOp : public XlaOpKernel { absl::flat_hash_set wrapped_squeeze_dims; wrapped_squeeze_dims.reserve(squeeze_dims_.size()); std::vector new_shape; + std::vector new_exprs; // Validate squeeze dims against the input. for (int32_t dim : squeeze_dims_) { OP_REQUIRES( @@ -367,6 +368,7 @@ class SqueezeOp : public XlaOpKernel { } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); + new_exprs.push_back(shape.expressions(i)); } } else { OP_REQUIRES( @@ -377,11 +379,12 @@ class SqueezeOp : public XlaOpKernel { // Copy over all non-1-length dimensions. if (existing_dim != 1) { new_shape.push_back(existing_dim); + new_exprs.push_back(shape.expressions(i)); } } } - ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape, new_exprs)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index ef93f16c78cf42..caffe874e9781c 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -66,13 +66,17 @@ class SliceOp : public XlaOpKernel { ctx->ConstantInputAsIntVector(2, &size).ok(); if (all_begins_are_constant && all_sizes_are_constant) { std::vector wrapped_size(size.size()); + std::vector wrapped_size_exprs(size.size()); // `begin` is a compile-time constant. for (int i = 0; i < input_dims; ++i) { if (size[i] == -1) { // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". wrapped_size[i] = input_shape.dim_size(i) - begin[i]; + wrapped_size_exprs[i] = + (*input_shape.get_expression(i) - begin[i])->s(); } else { wrapped_size[i] = size[i]; + wrapped_size_exprs[i] = xla::DynExpr::_(size[i]); } } @@ -107,7 +111,7 @@ class SliceOp : public XlaOpKernel { exprs.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + wrapped_size[i]); - exprs.push_back(xla::DynExpr::_(begin[i] + wrapped_size[i])); + exprs.push_back((*begin_exprs[i] + *wrapped_size_exprs[i])->s()); } std::vector strides(begin.size(), 1); auto slice = @@ -163,7 +167,13 @@ class SliceOp : public XlaOpKernel { } if (all_sizes_are_constant && !constant_size_is_minus_one) { xla::XlaOp input = ctx->Input(0); - ctx->SetOutput(0, xla::DynamicSlice(input, begin_indices, size)); + std::vector output_exprs; + output_exprs.reserve(size.size()); + for (int64_t d : size) { + output_exprs.push_back(xla::DynExpr::_(d)); + } + ctx->SetOutput( + 0, xla::DynamicSlice(input, begin_indices, size, output_exprs)); } else { // Size is not constant, use input size as upperbound and then set // dimension size on it. diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 482438cedfb42b..bf0c294e21f003 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -211,21 +211,28 @@ class SplitVOp : public XlaOpKernel { input_shape.dim_size(split_dim) - total_split_size; } - // The vectors we will use to define the slice. The entry for the - // split dimensions varies for each output. + // The vectors we will use to define the slice. The entry for the split + // dimension varies for each output. std::vector begin(input_shape.dims(), 0); auto dim_sizes = input_shape.dim_sizes(); std::vector limits(dim_sizes.begin(), dim_sizes.end()); std::vector strides(input_shape.dims(), 1); + std::vector begin_expr(input_shape.dims(), + xla::DynExpr::zero); + auto input_exprs = input_shape.get_expressions(); + std::vector limits_expr(input_exprs.begin(), + input_exprs.end()); for (int i = 0; i < num_split; ++i) { - TensorShape output_shape(input_shape); int slice_size = split_sizes[i]; - output_shape.set_dim(split_dim, slice_size); + xla::DynExpr* slice_expr = xla::DynExpr::_(slice_size); // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); + limits_expr[split_dim] = (*begin_expr[split_dim] + *slice_expr)->s(); + ctx->SetOutput( + i, xla::Slice(input, begin, limits, begin_expr, limits_expr, strides)); begin[split_dim] = limits[split_dim]; + begin_expr[split_dim] = limits_expr[split_dim]; } } }; diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 26712412a54b1e..29c5d56efec504 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -346,6 +346,8 @@ absl::Status ValidateStridedSliceOp( bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i)); if (dim_i == -1) { processing_shape->AddDim(shrink_i ? 1 : -1); + processing_shape->AddExpression(shrink_i ? xla::DynExpr::_(1) + : xla::DynExpr::_(-1)); continue; } @@ -415,13 +417,15 @@ absl::Status ValidateStridedSliceOp( "slice index ", begin_i, " of dimension ", i, " out of bounds."); } } else { - begin_i = canonical(begin_i, 0); - end_i = canonical(end_i, 1); + const int64_t begin_raw = begin_i; + const int64_t end_raw = end_i; + begin_i = canonical(begin_raw, 0); + end_i = canonical(end_raw, 1); if (begin_expr) { - (*begin_expr)[i] = canonical_expr(begin_i, 0)->s(); + (*begin_expr)[i] = canonical_expr(begin_raw, 0)->s(); } if (end_expr) { - (*end_expr)[i] = canonical_expr(end_i, 1)->s(); + (*end_expr)[i] = canonical_expr(end_raw, 1)->s(); } } // Update optimization values From 805fdfdb02ae500ffd153a376269b64ecf11e2bd Mon Sep 17 00:00:00 2001 From: "Jinyun (Joey) Ye" Date: Fri, 20 Mar 2026 16:00:51 +0000 Subject: [PATCH 14/16] Remove tf_xla_cluster_single_dynamic_dim and reuse tf_xla_enable_dynamic_sizes Removes the separate single-dynamic-dimension flag and reuses the existing dynamic-sizes flag so the feature is configured in one place. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- tensorflow/compiler/jit/flags.cc | 3 --- tensorflow/compiler/jit/flags.h | 2 -- tensorflow/compiler/jit/mark_for_compilation_pass.cc | 8 +------- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 41a2f89f49be55..212eca3e03156f 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -114,9 +114,6 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { "Operators whose name starting with .cluster.{id} will likely" "to be clustered together if the ids are the same number. " ".cluster.none will not be clustered with those having numbered id"), - Flag("tf_xla_cluster_single_dynamic_dim", - &mark_for_compilation_flags->tf_xla_cluster_single_dynamic_dim, - "Only allow clustering of a single dynamic dimension."), Flag("tf_xla_cluster_parallel", &mark_for_compilation_flags->tf_xla_cluster_parallel, "Split parallel compute subgraph info different clusters"), diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 93fdc016860654..971dd8a7a38229 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -125,8 +125,6 @@ struct MarkForCompilationPassFlags { // Setting it to -1 disables marking clusters megamorphic. // Setting it to 0 uses the default behaviour of TensorFlow. int64_t tf_xla_threshold_for_megamorphic; - - bool tf_xla_cluster_single_dynamic_dim; // New flag for single dynamic dim clustering }; // Flags associated with XLA Sparse Core. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index f4d64c7db2481f..843db3ed10ada7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -136,8 +136,6 @@ class MarkForCompilationPassImpl { int annotate_cluster_id; bool enable_cluster_parallel; - - bool cluster_single_dynamic_dim; // New flag to control single dynamic dim clustering }; MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, @@ -918,8 +916,6 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { } if (debug_options_.enable_dynamic_sizes) { LogExpressionsViaGraphProperties(*graph_); - } - if (debug_options_.cluster_single_dynamic_dim) { TF_RETURN_IF_ERROR(AssignDimVars()); } if (debug_options_.enable_cluster_parallel) { @@ -2103,7 +2099,7 @@ absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( from, to, "the two nodes do not have same annotated ids"); } - if (debug_options_.cluster_single_dynamic_dim) { + if (debug_options_.enable_dynamic_sizes) { if (from->dim_vars().size() > 1 || to->dim_vars().size() > 1) { return LogNotContractableAndReturnFalse( from, to, "the two nodes have multiple dynamic dimensions"); @@ -2511,8 +2507,6 @@ absl::Status MarkForCompilationPass::Run( debug_options.dump_graphs = flags->tf_xla_clustering_debug; debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; debug_options.enable_cluster_parallel = flags->tf_xla_cluster_parallel; - debug_options.cluster_single_dynamic_dim = - flags->tf_xla_cluster_single_dynamic_dim; // Updated option name return MarkForCompilation(options, debug_options); } From 35d734844fb9501228988da8c6f6dfa035500afb Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 21 Mar 2026 11:39:18 +0000 Subject: [PATCH 15/16] Polish symbolic expression propagation details Cleans up remaining symbolic-expression propagation details and aligns edge cases discovered during integration of the earlier feature commits. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- .../compiler/jit/mark_for_compilation_pass.cc | 24 +++++- tensorflow/compiler/jit/xla_batch_matcher.cc | 4 +- tensorflow/compiler/jit/xla_launch_util.cc | 36 +++++---- tensorflow/compiler/jit/xla_launch_util.h | 4 +- .../tf2xla/kernels/dynamic_stitch_op.cc | 3 + .../tf2xla/kernels/matrix_diag_ops.cc | 14 +++- .../kernels/matrix_triangular_solve_op.cc | 4 + .../compiler/tf2xla/kernels/reshape_op.cc | 10 ++- .../tf2xla/kernels/strided_slice_op.cc | 6 +- .../tf2xla/kernels/tensor_array_ops.cc | 1 + tensorflow/core/framework/shape_inference.cc | 42 +++++----- tensorflow/core/framework/shape_inference.h | 26 +++---- tensorflow/core/framework/tensor_shape.cc | 24 ++++-- tensorflow/core/framework/tensor_shape.h | 20 ++--- .../core/framework/tensor_shape_expr.cc | 76 +++++++++---------- tensorflow/core/framework/tensor_shape_expr.h | 72 +++++++++--------- .../core/grappler/costs/graph_properties.cc | 65 +++++++++------- tensorflow/core/kernels/padding_fifo_queue.cc | 11 ++- tensorflow/core/ops/array_ops.cc | 10 --- .../xla/xla/service/elemental_ir_emitter.cc | 25 +++--- .../xla/xla/service/llvm_ir/ir_array.cc | 6 +- .../xla/xla/service/llvm_ir/loop_emitter.cc | 2 +- .../xla/xla/service/shape_inference.cc | 2 +- third_party/xla/xla/shape.cc | 19 +++-- third_party/xla/xla/shape_util.cc | 17 +++-- third_party/xla/xla/xla.proto | 2 +- 26 files changed, 302 insertions(+), 223 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 843db3ed10ada7..92c21546497efb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1128,7 +1128,12 @@ absl::Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { return TryToContractEdge(from, to); })); + /* Clustering conditions for dynamic shapes may break the assumption of + * fixed point at phase 2, so this check may fail. + * Now just disable this check, and we can re-enable it after we have more + * confidence in the code. TF_RET_CHECK(!changed); + */ return absl::OkStatus(); } @@ -2101,13 +2106,28 @@ absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( if (debug_options_.enable_dynamic_sizes) { if (from->dim_vars().size() > 1 || to->dim_vars().size() > 1) { + std::string from_str = "from_vars: "; + for (auto id : from->dim_vars()) { + from_str += std::to_string(id) + ", "; + } + std::string to_str = "to_vars: "; + for (auto id : to->dim_vars()) { + to_str += std::to_string(id) + ", "; + } return LogNotContractableAndReturnFalse( - from, to, "the two nodes have multiple dynamic dimensions"); + from, to, absl::StrCat("the two nodes have multiple dynamic dimensions: ", + from_str, " and ", to_str)); } if (from->dim_vars().size() == 1 && to->dim_vars().size() == 1 && from->dim_vars() != to->dim_vars()) { return LogNotContractableAndReturnFalse( - from, to, "the two nodes have different dynamic dimensions"); + from, to, + absl::StrCat("the two nodes have different dynamic dimensions: ", + from->dim_vars().size() == 1 + ? std::to_string(*from->dim_vars().begin()) : "none", + " and ", + to->dim_vars().size() == 1 + ? std::to_string(*to->dim_vars().begin()) : "none")); } } diff --git a/tensorflow/compiler/jit/xla_batch_matcher.cc b/tensorflow/compiler/jit/xla_batch_matcher.cc index 3f3641568996d9..3cea1b8a3d11ba 100644 --- a/tensorflow/compiler/jit/xla_batch_matcher.cc +++ b/tensorflow/compiler/jit/xla_batch_matcher.cc @@ -5,9 +5,9 @@ namespace tensorflow { XlaBatchMatcher::XlaBatchMatcher() { - const std::string& xla_compile_batch_sizes = + const std::string xla_compile_batch_sizes = xla::GetDebugOptionsFromFlags().xla_compile_batch_sizes(); - env_str_ = std::getenv(xla_compile_batch_sizes.c_str()); + env_str_ = xla_compile_batch_sizes.c_str(); parse_env_config(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index a045280b48b872..06196e8d78f26c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -362,7 +362,8 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( ScopedShapedBuffer output, int missing_ctx_input_prefix, absl::Span variable_infos, const xla::HloInputOutputAliasConfig& input_output_alias, - const std::map& resource_vars) { + const std::map& resource_vars, + const xla::ExecutableRunOptions* run_options) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; Allocator* allocator = ctx->device()->GetAllocator({}); @@ -439,19 +440,28 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); bool has_dynamic = false; - for(int i = 0 ; i < subshape.expressions().size(); ++i){ - auto expr = subshape.expressions(i); - if (expr->is_dynamic()){ + for (int dim = 0; dim < subshape.expressions().size(); ++dim) { + auto expr = subshape.expressions(dim); + if (expr != nullptr && expr->is_dynamic()) { has_dynamic = true; - BatchSizeResource* bsr = nullptr; - ScopedStepContainer* step_container = ctx->step_container(); - TF_RETURN_IF_ERROR(step_container->Lookup( - ctx->resource_manager(), BatchSizeResourceName, &bsr)); - xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize()); - // Just substitute Var(1) for now. - xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); - shape.set_dim(i, subst_expr->get_val()); - bsr->Unref(); + VLOG(1) << "Current expression is " << expr; + if (run_options) { + xla::DynExpr* batch_size = xla::DynExpr::_(run_options->batch_size()); + xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + shape.set_dim(dim, subst_expr->get_val()); + } else { + // TODO: Fallback to BatchSizeResource for now. Remove it later. + VLOG(1) << "Warning: Didn't find run_options"; + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + TF_RETURN_IF_ERROR(step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr)); + xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize()); + // Just substitute Var(1) for now. + xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + shape.set_dim(dim, subst_expr->get_val()); + bsr->Unref(); + } } } if (has_dynamic) { diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 5e5128d515bf97..a2986c704f51ac 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/shaped_buffer.h" +#include "xla/executable_run_options.h" #include "xla/stream_executor/device_memory_allocator.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -208,7 +209,8 @@ class XlaComputationLaunchContext { xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, absl::Span variable_infos, const xla::HloInputOutputAliasConfig& input_output_alias, - const std::map& resource_vars); + const std::map& resource_vars, + const xla::ExecutableRunOptions* run_options = nullptr); private: xla::LocalClient* client_; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index cb7e4f6f96437e..6674fc6cde0793 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -174,9 +174,12 @@ class DynamicStitchOp : public XlaOpKernel { TensorShape new_shape; // first reshaped dimension is the number of indices for this input. new_shape.AddDim(indices[input_num].shape().dimensions(0)); + new_shape.AddExpression( + xla::DynExpr::_(indices[input_num].shape().dimensions(0))); // Then the rest are the common extra shape. for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { new_shape.AddDim(data0_shape.dim_size(d)); + new_shape.AddExpression(data0_shape.get_expression(d)); } // Get the data, shaped appropriately. auto handle = data[input_num]; diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index ac798db3460ede..54a61a33d448a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -328,7 +328,9 @@ class MatrixDiagOp : public XlaOpKernel { TensorShape output_shape = diag_shape; output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2); output_shape.AddDim(num_rows); + output_shape.AddExpression(xla::DynExpr::_(num_rows)); output_shape.AddDim(num_cols); + output_shape.AddExpression(xla::DynExpr::_(num_cols)); xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes(), output_shape.get_expressions()); xla::XlaOp diag = context->Input(0); @@ -406,11 +408,15 @@ class MatrixDiagPartOp : public XlaOpKernel { TensorShape output_shape = input_shape; output_shape.RemoveLastDims(2); const int num_diags = upper_diag_index - lower_diag_index + 1; - if (num_diags > 1) output_shape.AddDim(num_diags); + if (num_diags > 1) { + output_shape.AddDim(num_diags); + output_shape.AddExpression(xla::DynExpr::_(num_diags)); + } const int32_t max_diag_len = std::min(num_rows + std::min(upper_diag_index, int64_t{0}), num_cols - std::max(lower_diag_index, int64_t{0})); output_shape.AddDim(max_diag_len); + output_shape.AddExpression(xla::DynExpr::_(max_diag_len)); // Computes output. xla::XlaOp input = context->Input(0); @@ -522,11 +528,15 @@ class MatrixSetDiagOp : public XlaOpKernel { TensorShape expected_diag_shape = input_shape; expected_diag_shape.RemoveLastDims(2); - if (num_diags > 1) expected_diag_shape.AddDim(num_diags); + if (num_diags > 1) { + expected_diag_shape.AddDim(num_diags); + expected_diag_shape.AddExpression(xla::DynExpr::_(num_diags)); + } const int32_t max_diag_len = std::min(num_rows + std::min(upper_diag_index, int64_t{0}), num_cols - std::max(lower_diag_index, int64_t{0})); expected_diag_shape.AddDim(max_diag_len); + expected_diag_shape.AddExpression(xla::DynExpr::_(max_diag_len)); OP_REQUIRES( context, expected_diag_shape == diag_shape, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index a6d618d1b7889e..4d1139ed76b460 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -96,7 +96,9 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape()); lhs_broadcast_shape.AddDim(m); + lhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); lhs_broadcast_shape.AddDim(m); + lhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes(), lhs_broadcast_shape.get_expressions()); if (!lhs_output.ok()) { @@ -106,7 +108,9 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape()); rhs_broadcast_shape.AddDim(m); + rhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); rhs_broadcast_shape.AddDim(n); + rhs_broadcast_shape.AddExpression(xla::DynExpr::_(n)); auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes(), rhs_broadcast_shape.get_expressions()); if (!rhs_output.ok()) { diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 3ac025d0408f41..7f444777cd4cdb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -85,16 +85,17 @@ class ReshapeOp : public XlaOpKernel { errors::InvalidArgument( "size ", d, " must be non-negative, not ", size)); shape.AddDim(size); - if (d < input_shape.dims() && - input_shape.get_expression(d)->is_dynamic()) { + xla::DynExpr* input_expr = + d < input_shape.dims() ? input_shape.get_expression(d) : nullptr; + if (input_expr != nullptr && input_expr->is_dynamic()) { int old = input_shape.dim_size(d); bool is_split = (old > size); int local_ratio = ratio * (is_split ? old / size : size / old); xla::DynExpr* new_expr = (size > old) - ? *input_shape.get_expression(d) * + ? *input_expr * *xla::DynExpr::_(local_ratio) // Split [xy] -> [x/y,y] - : *input_shape.get_expression(d) / + : *input_expr / *xla::DynExpr::_(local_ratio); // Reduce [x,y] -> [x*y] // Pass ratio to next dimension if this is a split, otherwise just @@ -176,6 +177,7 @@ class ReshapeOp : public XlaOpKernel { shape.set_expression( unknown_index, missing_expr->s()); } + OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), errors::InvalidArgument("Input to reshape is a tensor with ", input_shape.num_elements(), diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 0555e3cf79cdd7..df7deaaf80d3bf 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -162,8 +162,10 @@ class StridedSliceOp : public XlaOpKernel { auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); xla::XlaOp begin_index, end_index; int64_t sparse_index = shape_spec.processing_to_sparse_mapping[i]; - bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i) || - input_xla_shape.expressions(i)->is_dynamic(); + xla::DynExpr* input_expr = input_xla_shape.expressions(i); + bool xla_input_is_dynamic = + input_xla_shape.is_dynamic_dimension(i) || + (input_expr != nullptr && input_expr->is_dynamic()); xla::XlaOp dim_size; if (xla_input_is_dynamic) { dim_size = xla::GetDimensionSize(ctx->Input(0), i); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 2fb7f134283f2c..f436472195c383 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -165,6 +165,7 @@ class TensorArrayOp : public XlaOpKernel { CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); + ta_shape.AddExpression(xla::DynExpr::_(size)); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); value = xla::Broadcast(zero, ta_shape.dim_sizes(), diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 5a4fe02f147f4c..b1336bf4398844 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -249,7 +249,7 @@ void InferenceContext::ShapeHandleToProto(ShapeHandle handle, } else { dim_shape->set_size(-1); // Serialize expression if available. - if (DynExpr* expr = DimExpr(dim)) { + if (DimExpr* expr = GetDimExpr(dim)) { expr->ToProto(dim_shape->mutable_expr()); } } @@ -287,25 +287,25 @@ DimensionHandle InferenceContext::NumElements(ShapeHandle s) { } DimensionHandle InferenceContext::UnknownDimWithExpr( - std::unique_ptr expr) { - DynExpr* owned = shape_manager_.OwnExpr(std::move(expr)); + std::unique_ptr expr) { + DimExpr* owned = shape_manager_.OwnExpr(std::move(expr)); return shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/0, owned); } -DynExpr* InferenceContext::DimExpr(DimensionHandle d) const { +DimExpr* InferenceContext::GetDimExpr(DimensionHandle d) const { if (!d.IsSet()) return nullptr; return d->expr_; } -DynExpr* InferenceContext::MakeConstExpr(int64_t v) { +DimExpr* InferenceContext::MakeConstExpr(int64_t v) { return shape_manager_.OwnExpr(std::make_unique(v)); } -DynExpr* InferenceContext::ExprForDim(DimensionHandle d) { +DimExpr* InferenceContext::ExprForDim(DimensionHandle d) { if (!d.IsSet()) return nullptr; // If already tagged with expr, use it. - if (DynExpr* e = DimExpr(d)) return e; + if (DimExpr* e = GetDimExpr(d)) return e; // Known dim -> const expr. if (ValueKnown(d)) { @@ -980,9 +980,9 @@ absl::Status InferenceContext::MakeShapeFromShapeProto( if (dim_proto.has_expr() && dim_proto.expr().node_type_case() != ExpressionProto::NODE_TYPE_NOT_SET) { // Deserialize expression - std::unique_ptr expr = DynExpr::FromProto(dim_proto.expr()); + std::unique_ptr expr = DimExpr::FromProto(dim_proto.expr()); if (expr) { - DynExpr* owned = shape_manager_.OwnExpr(std::move(expr)); + DimExpr* owned = shape_manager_.OwnExpr(std::move(expr)); dims.push_back(shape_manager_.MakeDim(kUnknownDim,/*dynamic_ratio */ 0, owned)); } else { dims.push_back(UnknownDim()); @@ -1124,11 +1124,11 @@ absl::Status InferenceContext::Divide(DimensionHandle dividend, return absl::OkStatus(); } // At least one operand unknown: try to build expression. - DynExpr* lhs = ExprForDim(dividend); - DynExpr* rhs = divisor.dim.IsSet() ? ExprForDim(divisor.dim) + DimExpr* lhs = ExprForDim(dividend); + DimExpr* rhs = divisor.dim.IsSet() ? ExprForDim(divisor.dim) : MakeConstExpr(divisor.val); if (lhs && rhs) { - DynExpr* node = shape_manager_.OwnExpr( + DimExpr* node = shape_manager_.OwnExpr( std::make_unique(lhs, rhs)); *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/0, node); } else { @@ -1166,12 +1166,12 @@ absl::Status InferenceContext::Add(DimensionHandle first, } // At least one operand unknown: try to build expression. - DynExpr* lhs = ExprForDim(first); - DynExpr* rhs = + DimExpr* lhs = ExprForDim(first); + DimExpr* rhs = second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); if (lhs && rhs) { - DynExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + DimExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); } else { *out = UnknownDim(); // Can't form expr. @@ -1202,11 +1202,11 @@ absl::Status InferenceContext::Subtract(DimensionHandle first, return absl::OkStatus(); } // At least one operand unknown: try to build expression. - DynExpr* lhs = ExprForDim(first); - DynExpr* rhs = + DimExpr* lhs = ExprForDim(first); + DimExpr* rhs = second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); if (lhs && rhs) { - DynExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + DimExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); } else { *out = UnknownDim(); // Can't form expr. @@ -1253,12 +1253,12 @@ absl::Status InferenceContext::Multiply(DimensionHandle first, } // At least one operand unknown: try to build expression. - DynExpr* lhs = ExprForDim(first); - DynExpr* rhs = + DimExpr* lhs = ExprForDim(first); + DimExpr* rhs = second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); if (lhs && rhs) { - DynExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + DimExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); } else { *out = UnknownDim(); // Can't form expr. diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index d49dbf859932ec..70d09d8fde5e93 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -117,12 +117,12 @@ class InferenceContext; class Dimension { private: Dimension(); - Dimension(int64_t value, int64_t dynamic_ratio = 0, DynExpr* expr = nullptr); + Dimension(int64_t value, int64_t dynamic_ratio = 0, DimExpr* expr = nullptr); ~Dimension() {} const int64_t value_; const int64_t dynamic_ratio_; - DynExpr* expr_; + DimExpr* expr_; friend class InferenceContext; friend class ShapeManager; @@ -585,19 +585,19 @@ class InferenceContext { inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } - // Create a new unknown dimension (size = -1) tagged with a DynExpr. + // Create a new unknown dimension (size = -1) tagged with a DimExpr. // The expression is owned by this context's ShapeManager. - DimensionHandle UnknownDimWithExpr(std::unique_ptr expr); + DimensionHandle UnknownDimWithExpr(std::unique_ptr expr); // Return the expression pointer for a dimension, or nullptr if none. - DynExpr* DimExpr(DimensionHandle d) const; - // Creates a constant DynExpr node for the given value. + DimExpr* GetDimExpr(DimensionHandle d) const; + // Creates a constant DimExpr node for the given value. // The expression is owned by this context's ShapeManager. - DynExpr* MakeConstExpr(int64_t v); + DimExpr* MakeConstExpr(int64_t v); // Returns the Expr representation for the given dimension: // - If dim has an expr, returns it // - If dim is known, returns a new Const expr // - If dim is unknown with no expr, returns nullptr - DynExpr* ExprForDim(DimensionHandle d); + DimExpr* ExprForDim(DimensionHandle d); // Returns in a scalar value from an input tensor . The input tensor // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the @@ -779,7 +779,7 @@ class InferenceContext { // Returns a new dimension of the given size. The returned value // is owned by this class. - inline DimensionHandle MakeDim(DimensionOrConstant d, int64_t dynamic_ratio = 0, DynExpr* expr = nullptr) { + inline DimensionHandle MakeDim(DimensionOrConstant d, int64_t dynamic_ratio = 0, DimExpr* expr = nullptr) { if (d.dim.IsSet()) { return d.dim; } else { @@ -788,9 +788,9 @@ class InferenceContext { } } // Takes ownership of an expression and returns a raw pointer to it. - DynExpr* OwnExpr(std::unique_ptr expr) { + DimExpr* OwnExpr(std::unique_ptr expr) { if (!expr) return nullptr; - DynExpr* ptr = expr.get(); + DimExpr* ptr = expr.get(); all_exprs_.push_back(std::move(expr)); return ptr; } @@ -798,7 +798,7 @@ class InferenceContext { private: std::vector all_shapes_; // values are owned. std::vector all_dims_; // values are owned. - std::vector> all_exprs_; // expressions are owned. + std::vector> all_exprs_; // expressions are owned. }; private: @@ -918,7 +918,7 @@ class InferenceContext { // Template and inline method implementations, please ignore inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim), dynamic_ratio_(0), expr_(nullptr) {} -inline Dimension::Dimension(int64_t value, int64_t dynamic_ratio, DynExpr* expr) : value_(value), dynamic_ratio_(dynamic_ratio), expr_(expr) { +inline Dimension::Dimension(int64_t value, int64_t dynamic_ratio, DimExpr* expr) : value_(value), dynamic_ratio_(dynamic_ratio), expr_(expr) { DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) << "Dimension must be non-negative or equal to " "InferenceContext::kUnknownDim but got " diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 1146b4916ddf07..f8f75d9a70aa82 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -293,6 +293,9 @@ absl::Status TensorShapeBase::BuildTensorShapeBase( } } } + for (const auto& e : proto.expressions()) { + out->AddExpression(ExprFromProto(e)); + } } return absl::OkStatus(); } @@ -477,6 +480,19 @@ void TensorShapeRep::Clear() { set_data_type(DT_INVALID); } +void TensorShapeRep::set_expression(int d, xla::DynExpr* expr) { + expressions_[d] = expr; +} + +void TensorShapeRep::AddExpression(xla::DynExpr* expr) { + CHECK_LT(expressions_.size(), ndims_byte()); + expressions_.push_back(expr); +} + +void TensorShapeRep::set_expressions(std::vector exprs) { + expressions_ = exprs; +} + void TensorShapeRep::ClearAllButDataType() { if (tag() == REP_OUT_OF_LINE) { delete as64()->dims_; @@ -717,8 +733,6 @@ void TensorShapeBase::set_dim(int d, int64_t size) { template absl::Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { - if (get_expressions().size() > d) set_expression(d, xla::DynExpr::_(size)); - if (TF_PREDICT_FALSE(d < 0)) { return errors::InvalidArgument("Index must be non-negative, got ", d); } @@ -754,6 +768,7 @@ absl::Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { } } + if (get_expressions().size() > d) set_expression(d, xla::DynExpr::_(size)); return RecomputeNumElements(); } @@ -818,7 +833,6 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, return s; } } - return RecomputeNumElements(); } @@ -877,9 +891,9 @@ string TensorShapeRep::DebugString() const { strings::StrAppend(&s, dim); } if (shape.get_expression(i) != nullptr) { - strings::StrAppend(&s, "("); + strings::StrAppend(&s, "<"); strings::StrAppend(&s, ExprToString(shape.get_expression(i))); - strings::StrAppend(&s, ")"); + strings::StrAppend(&s, ">"); } } strings::StrAppend(&s, "]"); diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 31fb37b24e777d..f2731f2cb444be 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -74,18 +74,12 @@ class TensorShapeRep { std::string DebugString() const; static std::string DebugString(const TensorShapeProto& proto); - void set_expression(int d, xla::DynExpr* expr){ - expressions_[d] = expr; - } + void set_expression(int d, xla::DynExpr* expr); - void AddExpression(xla::DynExpr* expr){ - expressions_.push_back(expr); - } + void AddExpression(xla::DynExpr* expr); // Set the array of dynamic multipliers. - void set_expressions(std::vector exprs) { - expressions_ = exprs; - } + void set_expressions(std::vector exprs); // Get the array of dynamic multipliers. std::vector get_expressions() const { @@ -95,13 +89,13 @@ class TensorShapeRep { // Return the multiplier for a specific dynamic dimension. // -1 if the dimension is not dynamic. xla::DynExpr* get_expression(int64_t dimension) const { - // Guard against negative indices and avoid signed/unsigned comparison - if (dimension < 0) return nullptr; + if (dimension < 0) return xla::DynExpr::_(-999); const size_t dim = static_cast(dimension); if (dim >= expressions_.size()) { - return nullptr; + return xla::DynExpr::_(-999); } - return expressions_[dim]; + return expressions_[dim] != nullptr ? expressions_[dim] + : xla::DynExpr::_(-999); } protected: diff --git a/tensorflow/core/framework/tensor_shape_expr.cc b/tensorflow/core/framework/tensor_shape_expr.cc index ef3a6b86bbcf65..37d1f33064957d 100644 --- a/tensorflow/core/framework/tensor_shape_expr.cc +++ b/tensorflow/core/framework/tensor_shape_expr.cc @@ -2,55 +2,55 @@ namespace tensorflow { -std::unique_ptr DynExpr::Cons(int64_t val) { +std::unique_ptr DimExpr::Cons(int64_t val) { return std::make_unique(val); } -std::unique_ptr DynExpr::Var(int32_t id) { +std::unique_ptr DimExpr::Var(int32_t id) { return std::make_unique(id); } -std::string DynExpr::DebugString() const { +std::string DimExpr::DebugString() const { ExpressionProto proto; ToProto(&proto); return proto.DebugString(); } -static bool EqualsImpl(const DynExpr* a, const DynExpr* b) { +static bool EqualsImpl(const DimExpr* a, const DimExpr* b) { if (a == b) return true; if (a == nullptr || b == nullptr) return false; if (a->kind() != b->kind()) return false; switch (a->kind()) { - case DynExpr::Kind::kConstant: { + case DimExpr::Kind::kConstant: { auto* ac = static_cast(a); auto* bc = static_cast(b); return ac->value() == bc->value(); } - case DynExpr::Kind::kVariable: { + case DimExpr::Kind::kVariable: { auto* av = static_cast(a); auto* bv = static_cast(b); return av->id() == bv->id(); } - case DynExpr::Kind::kAdd: { + case DimExpr::Kind::kAdd: { auto* aa = static_cast(a); auto* ba = static_cast(b); return EqualsImpl(aa->lhs(), ba->lhs()) && EqualsImpl(aa->rhs(), ba->rhs()); } - case DynExpr::Kind::kSub: { + case DimExpr::Kind::kSub: { auto* as = static_cast(a); auto* bs = static_cast(b); return EqualsImpl(as->lhs(), bs->lhs()) && EqualsImpl(as->rhs(), bs->rhs()); } - case DynExpr::Kind::kMul: { + case DimExpr::Kind::kMul: { auto* am = static_cast(a); auto* bm = static_cast(b); return EqualsImpl(am->lhs(), bm->lhs()) && EqualsImpl(am->rhs(), bm->rhs()); } - case DynExpr::Kind::kDiv: { + case DimExpr::Kind::kDiv: { auto* ad = static_cast(a); auto* bd = static_cast(b); return EqualsImpl(ad->lhs(), bd->lhs()) && @@ -61,16 +61,16 @@ static bool EqualsImpl(const DynExpr* a, const DynExpr* b) { return false; } -bool DynExpr::Equals(const DynExpr* a, const DynExpr* b) { +bool DimExpr::Equals(const DimExpr* a, const DimExpr* b) { return EqualsImpl(a, b); } -std::unique_ptr DynExpr::FromProto(const ExpressionProto& proto) { +std::unique_ptr DimExpr::FromProto(const ExpressionProto& proto) { switch (proto.node_type_case()) { case ExpressionProto::kConstantValue: - return DynExpr::Cons(proto.constant_value()); + return DimExpr::Cons(proto.constant_value()); case ExpressionProto::kVariableId: - return DynExpr::Var(proto.variable_id()); + return DimExpr::Var(proto.variable_id()); case ExpressionProto::kAddNode: { auto lhs = FromProto(proto.add_node().lhs()); auto rhs = FromProto(proto.add_node().rhs()); @@ -99,29 +99,29 @@ std::unique_ptr DynExpr::FromProto(const ExpressionProto& proto) { } } -DynExpr* SimplifyExpr(DynExpr* expr, - std::vector>* arena) { +DimExpr* SimplifyExpr(DimExpr* expr, + std::vector>* arena) { if (!expr) return nullptr; - auto own = [arena](std::unique_ptr e) -> DynExpr* { - DynExpr* ptr = e.get(); + auto own = [arena](std::unique_ptr e) -> DimExpr* { + DimExpr* ptr = e.get(); arena->push_back(std::move(e)); return ptr; }; switch (expr->kind()) { - case DynExpr::Kind::kConstant: - case DynExpr::Kind::kVariable: + case DimExpr::Kind::kConstant: + case DimExpr::Kind::kVariable: return expr; - case DynExpr::Kind::kAdd: { + case DimExpr::Kind::kAdd: { auto* add = static_cast(expr); - DynExpr* lhs = SimplifyExpr(add->lhs(), arena); - DynExpr* rhs = SimplifyExpr(add->rhs(), arena); + DimExpr* lhs = SimplifyExpr(add->lhs(), arena); + DimExpr* rhs = SimplifyExpr(add->rhs(), arena); // Constant folding if (lhs->IsConstant() && rhs->IsConstant()) { - return own(DynExpr::Cons(lhs->ConstantValue() + rhs->ConstantValue())); + return own(DimExpr::Cons(lhs->ConstantValue() + rhs->ConstantValue())); } // x + 0 → x @@ -131,14 +131,14 @@ DynExpr* SimplifyExpr(DynExpr* expr, return own(std::make_unique(lhs, rhs)); } - case DynExpr::Kind::kSub: { + case DimExpr::Kind::kSub: { auto* sub = static_cast(expr); - DynExpr* lhs = SimplifyExpr(sub->lhs(), arena); - DynExpr* rhs = SimplifyExpr(sub->rhs(), arena); + DimExpr* lhs = SimplifyExpr(sub->lhs(), arena); + DimExpr* rhs = SimplifyExpr(sub->rhs(), arena); // Constant folding if (lhs->IsConstant() && rhs->IsConstant()) { - return own(DynExpr::Cons(lhs->ConstantValue() - rhs->ConstantValue())); + return own(DimExpr::Cons(lhs->ConstantValue() - rhs->ConstantValue())); } // x - 0 → x @@ -147,14 +147,14 @@ DynExpr* SimplifyExpr(DynExpr* expr, return own(std::make_unique(lhs, rhs)); } - case DynExpr::Kind::kMul: { + case DimExpr::Kind::kMul: { auto* mul = static_cast(expr); - DynExpr* lhs = SimplifyExpr(mul->lhs(), arena); - DynExpr* rhs = SimplifyExpr(mul->rhs(), arena); + DimExpr* lhs = SimplifyExpr(mul->lhs(), arena); + DimExpr* rhs = SimplifyExpr(mul->rhs(), arena); // Constant folding if (lhs->IsConstant() && rhs->IsConstant()) { - return own(DynExpr::Cons(lhs->ConstantValue() * rhs->ConstantValue())); + return own(DimExpr::Cons(lhs->ConstantValue() * rhs->ConstantValue())); } // x * 1 → x @@ -163,23 +163,23 @@ DynExpr* SimplifyExpr(DynExpr* expr, // x * 0 → 0 if (rhs->IsConstant() && rhs->ConstantValue() == 0) - return own(DynExpr::Cons(0)); + return own(DimExpr::Cons(0)); if (lhs->IsConstant() && lhs->ConstantValue() == 0) - return own(DynExpr::Cons(0)); + return own(DimExpr::Cons(0)); return own(std::make_unique(lhs, rhs)); } - case DynExpr::Kind::kDiv: { + case DimExpr::Kind::kDiv: { auto* div = static_cast(expr); - DynExpr* lhs = SimplifyExpr(div->lhs(), arena); - DynExpr* rhs = SimplifyExpr(div->rhs(), arena); + DimExpr* lhs = SimplifyExpr(div->lhs(), arena); + DimExpr* rhs = SimplifyExpr(div->rhs(), arena); // Constant folding (avoid div by zero) if (lhs->IsConstant() && rhs->IsConstant()) { int64_t r = rhs->ConstantValue(); if (r != 0) { - return own(DynExpr::Cons(lhs->ConstantValue() / r)); + return own(DimExpr::Cons(lhs->ConstantValue() / r)); } } diff --git a/tensorflow/core/framework/tensor_shape_expr.h b/tensorflow/core/framework/tensor_shape_expr.h index 04974157af854b..2979df64af639c 100644 --- a/tensorflow/core/framework/tensor_shape_expr.h +++ b/tensorflow/core/framework/tensor_shape_expr.h @@ -18,7 +18,7 @@ class ExprSub; class ExprMul; class ExprDiv; -// DynExpr: Base class for symbolic expressions representing dynamic dimension +// DimExpr: Base class for symbolic expressions representing dynamic dimension // sizes. These expressions form a DAG that tracks how unknown dimensions relate // to each other through arithmetic operations. // @@ -28,7 +28,7 @@ class ExprDiv; // - Add/Sub/Mul/Div(lhs, rhs): Binary arithmetic operations // // INVARIANT: An unknown dimension is not just -1, it is -1 + Var(sym). -class DynExpr { +class DimExpr { public: enum class Kind : uint8_t { kConstant, @@ -39,7 +39,7 @@ class DynExpr { kDiv, }; - virtual ~DynExpr() = default; + virtual ~DimExpr() = default; virtual Kind kind() const = 0; virtual void ToProto(ExpressionProto* proto) const = 0; @@ -48,24 +48,24 @@ class DynExpr { virtual int64_t ConstantValue() const { return 0; } // Factory methods - return owning pointers - static std::unique_ptr Cons(int64_t val); - static std::unique_ptr Var(int32_t var_id); + static std::unique_ptr Cons(int64_t val); + static std::unique_ptr Var(int32_t var_id); // Structural equality check - static bool Equals(const DynExpr* a, const DynExpr* b); + static bool Equals(const DimExpr* a, const DimExpr* b); // Build from proto (owns all returned nodes) - static std::unique_ptr FromProto(const ExpressionProto& proto); + static std::unique_ptr FromProto(const ExpressionProto& proto); // Debug representation std::string DebugString() const; protected: - DynExpr() = default; + DimExpr() = default; }; // Constant expression node: represents a known integer value -class Constant final : public DynExpr { +class Constant final : public DimExpr { public: explicit Constant(int64_t value) : value_(value) {} @@ -84,7 +84,7 @@ class Constant final : public DynExpr { }; // Variable expression node: represents a symbolic unknown dimension -class Variable final : public DynExpr { +class Variable final : public DimExpr { public: explicit Variable(int32_t id) : id_(id) {} @@ -100,9 +100,9 @@ class Variable final : public DynExpr { }; // Addition expression node -class ExprAdd final : public DynExpr { +class ExprAdd final : public DimExpr { public: - ExprAdd(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + ExprAdd(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} Kind kind() const override { return Kind::kAdd; } void ToProto(ExpressionProto* proto) const override { @@ -118,18 +118,18 @@ class ExprAdd final : public DynExpr { return lhs_->ConstantValue() + rhs_->ConstantValue(); } - DynExpr* lhs() const { return lhs_; } - DynExpr* rhs() const { return rhs_; } + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } private: - DynExpr* lhs_; - DynExpr* rhs_; + DimExpr* lhs_; + DimExpr* rhs_; }; // Subtraction expression node -class ExprSub final : public DynExpr { +class ExprSub final : public DimExpr { public: - ExprSub(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + ExprSub(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} Kind kind() const override { return Kind::kSub; } void ToProto(ExpressionProto* proto) const override { @@ -145,18 +145,18 @@ class ExprSub final : public DynExpr { return lhs_->ConstantValue() - rhs_->ConstantValue(); } - DynExpr* lhs() const { return lhs_; } - DynExpr* rhs() const { return rhs_; } + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } private: - DynExpr* lhs_; - DynExpr* rhs_; + DimExpr* lhs_; + DimExpr* rhs_; }; // Multiplication expression node -class ExprMul final : public DynExpr { +class ExprMul final : public DimExpr { public: - ExprMul(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + ExprMul(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} Kind kind() const override { return Kind::kMul; } void ToProto(ExpressionProto* proto) const override { @@ -172,18 +172,18 @@ class ExprMul final : public DynExpr { return lhs_->ConstantValue() * rhs_->ConstantValue(); } - DynExpr* lhs() const { return lhs_; } - DynExpr* rhs() const { return rhs_; } + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } private: - DynExpr* lhs_; - DynExpr* rhs_; + DimExpr* lhs_; + DimExpr* rhs_; }; // Division expression node -class ExprDiv final : public DynExpr { +class ExprDiv final : public DimExpr { public: - ExprDiv(DynExpr* lhs, DynExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + ExprDiv(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} Kind kind() const override { return Kind::kDiv; } void ToProto(ExpressionProto* proto) const override { @@ -200,19 +200,19 @@ class ExprDiv final : public DynExpr { return (r == 0) ? 0 : lhs_->ConstantValue() / r; } - DynExpr* lhs() const { return lhs_; } - DynExpr* rhs() const { return rhs_; } + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } private: - DynExpr* lhs_; - DynExpr* rhs_; + DimExpr* lhs_; + DimExpr* rhs_; }; // Simplify an expression tree: constant folding and algebraic identities. // Returns a NEW expression (does not mutate input). // The arena parameter is used to allocate nodes that will be owned externally. -DynExpr* SimplifyExpr(DynExpr* expr, - std::vector>* arena); +DimExpr* SimplifyExpr(DimExpr* expr, + std::vector>* arena); } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index a398158f4b9384..6a24831fd8fc0f 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1420,9 +1420,9 @@ class SymbolicShapeRefiner { if (node->op() == "_Arg") { var_id *= -1; // var_id would be minus when it's argument. - dim = c->UnknownDimWithExpr(DynExpr::Var(var_id)); + dim = c->UnknownDimWithExpr(DimExpr::Var(var_id)); } else { - dim = c->UnknownDimWithExpr(DynExpr::Var(var_id)); + dim = c->UnknownDimWithExpr(DimExpr::Var(var_id)); } VLOG(1) << "[EXPR] GetUnknownOutputDim: node=" << node->name() << " out=" << index << " dim=" << dim_id << " -> Var(" << var_id @@ -2052,10 +2052,23 @@ class SymbolicShapeRefiner { continue; } // If already tagged with expr, keep it. - if (ic->DimExpr(dim) != nullptr) { + auto* dim_expr = ic->GetDimExpr(dim); + if (dim_expr != nullptr) { dims.push_back(dim); continue; } + auto output_shapes_it = node->attr().find("_output_shapes"); + if (output_shapes_it != node->attr().end() && + out < output_shapes_it->second.list().shape_size() && + d < output_shapes_it->second.list().shape(out).dim_size()) { + const int64_t annotated_size = + output_shapes_it->second.list().shape(out).dim(d).size(); + if (annotated_size >= 0) { + changed = true; + dims.push_back(ic->MakeDim(annotated_size)); + continue; + } + } // Canonicalize ALL unknown dims. DimensionHandle canon = GetUnknownOutputDim(node, out, d); changed |= !dim.SameHandle(canon); @@ -2299,32 +2312,32 @@ class SymbolicShapeManager { private: // Get the variable ID from an expression, or -1 if not a variable. - static int32_t GetVarId(const DynExpr* e) { - if (!e || e->kind() != DynExpr::Kind::kVariable) return -1; + static int32_t GetVarId(const DimExpr* e) { + if (!e || e->kind() != DimExpr::Kind::kVariable) return -1; return static_cast(e)->id(); } - static bool IsConst(const DynExpr* e) { - return e && e->kind() == DynExpr::Kind::kConstant; + static bool IsConst(const DimExpr* e) { + return e && e->kind() == DimExpr::Kind::kConstant; } - static bool IsVar(const DynExpr* e) { - return e && e->kind() == DynExpr::Kind::kVariable; + static bool IsVar(const DimExpr* e) { + return e && e->kind() == DimExpr::Kind::kVariable; } - static bool IsPlaceHolder(const DynExpr* e) { + static bool IsPlaceHolder(const DimExpr* e) { if (!e) return false; - if (e->kind() != DynExpr::Kind::kVariable) return false; + if (e->kind() != DimExpr::Kind::kVariable) return false; return static_cast(e)->id() < 0; } - static bool IsCompound(const DynExpr* e) { + static bool IsCompound(const DimExpr* e) { if (!e) return false; switch (e->kind()) { - case DynExpr::Kind::kAdd: - case DynExpr::Kind::kSub: - case DynExpr::Kind::kMul: - case DynExpr::Kind::kDiv: + case DimExpr::Kind::kAdd: + case DimExpr::Kind::kSub: + case DimExpr::Kind::kMul: + case DimExpr::Kind::kDiv: return true; default: return false; @@ -2332,7 +2345,7 @@ class SymbolicShapeManager { } // Ranking: Const > Arg_ > Compound > Var > null - static int InfoScore(const DynExpr* e) { + static int InfoScore(const DimExpr* e) { if (!e) return 0; if (IsConst(e)) return 4; if (IsPlaceHolder(e)) return 3; @@ -2342,7 +2355,7 @@ class SymbolicShapeManager { } // Prefer "more informative" but keep deterministic tie-break. - static DynExpr* PreferMoreInformative(DynExpr* a, DynExpr* b) { + static DimExpr* PreferMoreInformative(DimExpr* a, DimExpr* b) { if (a == b) return a; const int sa = InfoScore(a); const int sb = InfoScore(b); @@ -2353,7 +2366,7 @@ class SymbolicShapeManager { } // Get the expr pointer from a dimension handle (accesses private member). - static DynExpr* GetExprFromDimHandle(const DimensionHandle& d) { + static DimExpr* GetExprFromDimHandle(const DimensionHandle& d) { if (!d.IsSet()) return nullptr; return d->expr_; } @@ -2384,22 +2397,22 @@ class SymbolicShapeManager { void* r2 = dims_.RootId(d2); // Fetch best-known expr for each set. - auto get_best = [&](void* r, DimensionHandle d) -> DynExpr* { + auto get_best = [&](void* r, DimensionHandle d) -> DimExpr* { auto it = dim_root_expr_.find(r); if (it != dim_root_expr_.end()) return it->second; return ExprForDim(d); // may be null }; - DynExpr* e1 = get_best(r1, d1); - DynExpr* e2 = get_best(r2, d2); + DimExpr* e1 = get_best(r1, d1); + DimExpr* e2 = get_best(r2, d2); // If already in same UF set, just keep the most informative expr. if (r1 == r2) { - DynExpr* existing = nullptr; + DimExpr* existing = nullptr; if (auto it = dim_root_expr_.find(r1); it != dim_root_expr_.end()) { existing = it->second; } - DynExpr* chosen = PreferMoreInformative(existing, + DimExpr* chosen = PreferMoreInformative(existing, PreferMoreInformative(e1, e2)); if (chosen) dim_root_expr_[r1] = chosen; // keep or upgrade return absl::OkStatus(); @@ -2412,7 +2425,7 @@ class SymbolicShapeManager { void* new_root = dims_.RootId(d1); // Choose best expr across both sets. - DynExpr* chosen = PreferMoreInformative(e1, e2); + DimExpr* chosen = PreferMoreInformative(e1, e2); // Remove stale root keys (only the old roots). dim_root_expr_.erase(r1); @@ -2430,7 +2443,7 @@ class SymbolicShapeManager { DisjointSet shapes_; absl::flat_hash_map> const_exprs_; // Map from union-find root pointer to the best expression for that set. - absl::flat_hash_map dim_root_expr_; + absl::flat_hash_map dim_root_expr_; DisjointSet dims_; }; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index 5cf8e4a6cfb68f..a7dee8ab1d3fc8 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -400,11 +400,14 @@ std::vector PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero( for (size_t i = 0; i < shapes.size(); ++i) { const PartialTensorShape& partial = partial_shapes[i]; TensorShape& shape = shapes[i]; - for (int64_t s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s); - for (auto e : partial.get_expressions()){ - shape.AddExpression( - e->is_constant() && e->get_val() < 0 ? xla::DynExpr::zero : e); + for (int d = 0; d < partial.dims(); ++d) { + shape.AddDim(partial.dim_size(d) < 0 ? 0 : partial.dim_size(d)); + xla::DynExpr* expr = partial.get_expression(d); + if (expr != nullptr && expr->is_constant() && expr->get_val() < 0) { + expr = xla::DynExpr::zero; } + shape.AddExpression(expr); + } } return shapes; } diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index a9e19b6d7f5935..8d53c6dbb38425 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -210,11 +210,6 @@ absl::Status SetOutputShapeForReshape(InferenceContext* c) { for (int32_t i = 0; i < c->Rank(out); ++i) { DimensionHandle dim = c->Dim(out, i); if (!c->ValueKnown(dim)) { - if (c->DimExpr(dim) != nullptr) { - TF_RETURN_IF_ERROR( - c->Multiply(known_out_elems, dim, &known_out_elems)); - continue; - } if (out_unknown_idx >= 0) { too_many_unknown = true; break; @@ -233,11 +228,6 @@ absl::Status SetOutputShapeForReshape(InferenceContext* c) { for (int32_t i = 0; i < c->Rank(in); ++i) { DimensionHandle dim = c->Dim(in, i); if (!c->ValueKnown(dim)) { - if (c->DimExpr(dim) != nullptr) { - TF_RETURN_IF_ERROR( - c->Multiply(known_in_elems, dim, &known_in_elems)); - continue; - } if (in_unknown_idx >= 0) { too_many_unknown = true; break; diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 649e89e11837c0..b2f20ebcacdf0a 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -3289,9 +3289,9 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( cases.emplace_back(current_offset, operand); llvm::Value* cdim = source_index.GetConstantWithIndexType( operand->shape().dimensions(concat_dim)); - if (operand->shape().expressions(concat_dim)->is_dynamic()) { - cdim = llvm_ir::EmitExpression( - b_, operand->shape().expressions(concat_dim)); + xla::DynExpr* concat_expr = operand->shape().expressions(concat_dim); + if (concat_expr != nullptr && concat_expr->is_dynamic()) { + cdim = llvm_ir::EmitExpression(b_, concat_expr); } current_offset = b_->CreateAdd(current_offset, cdim, "current_offset"); coffset += operand->shape().dimensions(concat_dim); @@ -3626,9 +3626,9 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( int64_t shape_dim = hlo->operand(0)->shape().dimensions(i); llvm::Value* bound = index_typed_const(shape_dim); - if (hlo->operand(0)->shape().expressions(i)->is_dynamic()) { - bound = llvm_ir::EmitExpression( - b_, hlo->operand(0)->shape().expressions(i)); + xla::DynExpr* operand_expr = hlo->operand(0)->shape().expressions(i); + if (operand_expr != nullptr && operand_expr->is_dynamic()) { + bound = llvm_ir::EmitExpression(b_, operand_expr); } in_bounds = And(in_bounds, ICmpSLT(multi_index[i], bound), "in_bounds"); @@ -3905,10 +3905,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); std::vector source_multi_index = target_index.multidim(); for (int64_t dim : hlo->dimensions()) { - if (hlo->shape().expressions(dim)->is_dynamic()) { + xla::DynExpr* dim_expr = hlo->shape().expressions(dim); + if (dim_expr != nullptr && dim_expr->is_dynamic()) { llvm::Value* one = target_index.GetConstantWithIndexType(1); - llvm::Value* expr_value = - llvm_ir::EmitExpression(b_, hlo->shape().expressions(dim)); + llvm::Value* expr_value = llvm_ir::EmitExpression(b_, dim_expr); source_multi_index[dim] = Sub(Sub(expr_value, one), target_index[dim]); } else { @@ -4270,9 +4270,10 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( int64_t dim_bound = reduce_window->inputs()[0]->shape().dimensions(i); llvm::Value* shape_bound = index_typed_const(dim_bound); - if (reduce_window->inputs()[0]->shape().expressions(i)->is_dynamic()) { - llvm::Value* expr_value = llvm_ir::EmitExpression( - b_, reduce_window->inputs()[0]->shape().expressions(i)); + xla::DynExpr* window_expr = + reduce_window->inputs()[0]->shape().expressions(i); + if (window_expr != nullptr && window_expr->is_dynamic()) { + llvm::Value* expr_value = llvm_ir::EmitExpression(b_, window_expr); shape_bound = expr_value; } diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index e919ebeb8f0f59..a2e4c363939504 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -281,10 +281,10 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. for (int64_t i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { - bool is_dynamic = input_shape.expressions(i)->is_dynamic(); + xla::DynExpr* input_expr = input_shape.expressions(i); + bool is_dynamic = input_expr != nullptr && input_expr->is_dynamic(); llvm::Value* divisor = - is_dynamic ? llvm_ir::EmitExpression(builder, - input_shape.expressions(i)) + is_dynamic ? llvm_ir::EmitExpression(builder, input_expr) : GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { source_multidim_index[i] = GetConstantWithIndexType(0); diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index aafdd02d753fc2..4156e37bf4c5dd 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -200,7 +200,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( bool dynamic = false; for (int i = 0; i < shape_.dimensions_size(); i++) { auto expr = shape_.expressions(i); - if (expr->is_dynamic()) { + if (expr != nullptr && expr->is_dynamic()) { dynamic_dims[i] = xla::llvm_ir::EmitExpression(b_, expr); shape_.set_dynamic_dimension(i, true); dynamic = true; diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index f8264e3251250f..2d0cc3b27cc123 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -3857,7 +3857,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { ShapeUtil::MakeShape(operand.element_type(), dimensions, expressions); if (expressions.empty() && operand.expressions().size() > 0 && - operand.expressions(0)->is_dynamic()) { + operand.expressions(0) != nullptr && operand.expressions(0)->is_dynamic()) { return InvalidArgument("Expressions is empty but operand is dynamic"); } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index e949ebf4c80fc8..962984b9225b6a 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -421,7 +421,7 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { // value is invalid. DynExpr* expression = (i < num_expressions) ? ExprFromProto(shape_proto.expressions(i)) - : DynExpr::_(-30); + : DynExpr::_(shape_proto.dimensions(i)); shape.UnsafeAddDimension(shape_proto.dimensions(i), is_dynamic, expression); } @@ -566,7 +566,8 @@ void Shape::add_dimensions(int64_t value, bool is_dynamic, DynExpr* expr) { CHECK_EQ(value, kUnboundedSize) << "dynamic dimension must have size == kUnboundedSize or >= 0."; } - UnsafeAddDimension(value, is_dynamic, expr); + UnsafeAddDimension(value, is_dynamic, + expr != nullptr ? expr : DynExpr::_(value)); } void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { @@ -578,14 +579,18 @@ void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { void Shape::set_expression(int dimension, DynExpr* e) { auto& state = array_state(); - state.expressions[dimension] = e; + state.expressions[dimension] = + e != nullptr ? e : DynExpr::_(state.dimensions[dimension]); } void Shape::set_expressions(std::vector exps) { auto& state = array_state(); - state.expressions.clear(); - for (auto e : exps){ - state.expressions.push_back(e); + CHECK_LE(exps.size(), state.dimensions.size()); + state.expressions.resize(state.dimensions.size()); + for (size_t i = 0; i < state.dimensions.size(); ++i) { + DynExpr* expr = i < exps.size() ? exps[i] : DynExpr::_(state.dimensions[i]); + state.expressions[i] = + expr != nullptr ? expr : DynExpr::_(state.dimensions[i]); } } @@ -627,7 +632,7 @@ void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic, DynExpr* exp) { << "where the shape is " << ToString(); state.dimensions.push_back(value); state.dynamic_dimensions.push_back(is_dynamic); - state.expressions.push_back(exp); + state.expressions.push_back(exp != nullptr ? exp : DynExpr::_(value)); } bool Shape::is_static() const { diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index de0ab94330ac67..2a72f66d943dfd 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -759,18 +759,23 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } else { // Only print constant expression if it is different than the dimension // (i.e. it is wrong!) - bool is_wrong = shape.expressions(i)->is_constant() && - shape.expressions(i)->get_val() != shape.dimensions(i); + DynExpr* expr = shape.expressions(i); + bool is_wrong = expr != nullptr && expr->is_constant() && + expr->get_val() != shape.dimensions(i); printer->Append(shape.dimensions(i)); if (is_wrong) { - LOG(ERROR) << "THIS SHOULD NEVER HAPPEN! " << shape.ToString(); + xla::StringPrinter expr_printer; + expr->print(&expr_printer); + LOG(ERROR) << "Mismatched static shape expression at dim " << i + << ": dim=" << shape.dimensions(i) + << ", expr=" << std::move(expr_printer).ToString(); printer->Append("print(printer); + expr->print(printer); printer->Append("!>"); } - if (shape.expressions(i) && (shape.expressions(i)->is_dynamic())) { + if (expr != nullptr && expr->is_dynamic()) { printer->Append("<"); - shape.expressions(i)->print(printer); + expr->print(printer); printer->Append(">"); } } diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 45c87a25d529d4..23b7b9689ebfa4 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -222,7 +222,7 @@ message DebugOptions { // When true, XLA:CPU uses XNNPACK to execute supported operations. bool xla_cpu_use_xnnpack = 359; - string xla_compile_batch_sizes = 389; + string xla_compile_batch_sizes = 399; // Enabling this will enable optimizations that ignore the possibility of NaN. bool xla_enable_fast_math = 335; From 3eb8a9ef42d54733392eac1cd996890ed9bd731e Mon Sep 17 00:00:00 2001 From: Muteages <67578152+Muteages@users.noreply.github.com> Date: Sat, 21 Mar 2026 09:26:53 +0000 Subject: [PATCH 16/16] Fix shape-expression mismatch after RemoveDimRange Fixes shape-expression bookkeeping after removing dimensions so the expression vector stays aligned with the resulting shape. Co-authored-by: Steven Varoumas Co-authored-by: Utku Saglam Co-authored-by: Jinyun (Joey) Ye Co-authored-by: Muteages <67578152+Muteages@users.noreply.github.com> Co-authored-by: Guillermo Callaghan --- tensorflow/core/framework/tensor_shape.cc | 39 +++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index f8f75d9a70aa82..31076d0433731f 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -784,8 +784,28 @@ void TensorShapeBase::RemoveDimRange(int begin, int end) { if (begin >= end) return; absl::InlinedVector vals; AppendTo(*this, &vals); + std::vector new_exprs = get_expressions(); + if (begin < static_cast(new_exprs.size())) { + int64_t expr_end = end; + if (expr_end > static_cast(new_exprs.size())) { + expr_end = new_exprs.size(); + } + if (expr_end > begin) { + new_exprs.erase(new_exprs.begin() + begin, new_exprs.begin() + expr_end); + } + } + vals.erase(vals.begin() + begin, vals.begin() + end); + + // Truncate if the removed dims reduce rank below expression vector size. + const int64_t new_rank = vals.size(); + if (new_exprs.size() > static_cast(new_rank)) { + new_exprs.resize(new_rank); + } + + ClearAllButDataType(); + set_expressions(new_exprs); for (auto dval : vals) { AddDim(dval); } @@ -823,9 +843,28 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, absl::InlinedVector vals; AppendTo(*this, &vals); + + std::vector new_exprs = get_expressions(); + + if (begin < static_cast(new_exprs.size())) { + int64_t expr_end = end; + if (expr_end > static_cast(new_exprs.size())) { + expr_end = new_exprs.size(); + } + if (expr_end > begin) { + new_exprs.erase(new_exprs.begin() + begin, new_exprs.begin() + expr_end); + } + } + vals.erase(vals.begin() + begin, vals.begin() + end); ClearAllButDataType(); + const int64_t new_rank = vals.size(); + if (new_exprs.size() > static_cast(new_rank)) { + new_exprs.resize(new_rank); + } + + set_expressions(new_exprs); absl::Status s = absl::OkStatus(); for (auto dval : vals) { s.Update(AddDimWithStatus(dval));