diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 39f93d17aa2932..89a05cb80eba8f 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -780,6 +780,8 @@ cc_library( ":flags_headers", ":tf_graph_to_hlo_compiler", ":xla_compile_util", + ":xla_batch_matcher", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", @@ -2005,3 +2007,13 @@ tf_cuda_cc_test( "@local_xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", ], ) + +cc_library( + name = "xla_batch_matcher", + srcs = ["xla_batch_matcher.cc"], + hdrs = ["xla_batch_matcher.h"], + deps = [ + "//tensorflow/core/platform:logging", + "@local_xla//xla:debug_options_flags", + ], +) \ No newline at end of file diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc index 5e1b3b26e8ecb5..f8a742ccc01148 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.cc +++ b/tensorflow/compiler/jit/device_compilation_profiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -32,9 +33,30 @@ limitations under the License. 
namespace tensorflow { namespace { bool ShouldBeMegamorphic(int64_t compile_count, int64_t execution_count) { - const int64_t kCompileThreshold = 10; + int64_t kCompileThreshold = 10; const int64_t kMinExecutionsPerCompile = 50; + int64_t tf_xla_threshold_for_megamorphic = + GetMarkForCompilationPassFlags()->tf_xla_threshold_for_megamorphic; + + // Negative values other than -1 cannot be used + if (tf_xla_threshold_for_megamorphic < -1) { + LOG(FATAL) << "The value for the tf_xla_threshold_for_megamorphic flag " + << "is out of range.\n" + << "Allowed ranges are (-1) to " + << std::numeric_limits::max() + << " got " << tf_xla_threshold_for_megamorphic << "."; + } + + // -1: setting clusters as Megamorphic is disabled + // 0: Default behaviour in Tensorflow + // Any other number sets the compilation threshold + if (tf_xla_threshold_for_megamorphic == -1) { + return false; + } else if (tf_xla_threshold_for_megamorphic > 0) { + kCompileThreshold = tf_xla_threshold_for_megamorphic; + } + // This heuristic is trying to capture the following property: have we sunk a // certain minimum amount of compile time into the cluster that didn't quite // "pay off"? diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index 34b22033129b96..28e27c696e788a 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h" #include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/jit/xla_batch_matcher.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/op_kernel.h" @@ -125,7 +126,8 @@ class DeviceCompiler : public ResourceBase { DeviceCompilerClient* compiler_client() { return compiler_client_.get(); } - + XlaBatchMatcher* xla_batch_matcher() { return xla_batch_matcher_.get(); } + string DebugString() const override; private: @@ -177,6 +179,9 @@ class DeviceCompiler : public ResourceBase { // Pool of threads for asynchronous compilations. std::unique_ptr async_compiler_threads_; + // Specified dynamic batch padding values. + std::unique_ptr xla_batch_matcher_; + mutex cluster_mutexes_mu_; absl::flat_hash_map, DeviceCompilationClusterSignature::Hash> @@ -225,6 +230,11 @@ DeviceCompiler::DeviceCompiler( async_compiler_threads_ = std::make_unique( tensorflow::Env::Default(), "async_compiler_threads", kNumAsyncDeviceCompilerThreads); + + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + xla_batch_matcher_ = std::make_unique(); + } } template diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 3e8a43ce08ed58..2ac1b118d3f2ac 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -54,9 +55,14 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" - +#include "tensorflow/core/framework/tensor_shape.pb.h" namespace tensorflow { +static const absl::flat_hash_set kFailingOps = { + "Where", + // add more here +}; + const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; @@ -114,6 +120,30 @@ void MarkGuaranteedConstants( } } +// Helper to convert ExpressionProto to a readable string. +std::string ExprProtoToString(const ExpressionProto& e) { + switch (e.node_type_case()) { + case ExpressionProto::kConstantValue: + return std::to_string(e.constant_value()); + case ExpressionProto::kVariableId: + return absl::StrCat("Var(", e.variable_id(), ")"); + case ExpressionProto::kAddNode: + return absl::StrCat("(", ExprProtoToString(e.add_node().lhs()), " + ", + ExprProtoToString(e.add_node().rhs()), ")"); + case ExpressionProto::kSubNode: + return absl::StrCat("(", ExprProtoToString(e.sub_node().lhs()), " - ", + ExprProtoToString(e.sub_node().rhs()), ")"); + case ExpressionProto::kMulNode: + return absl::StrCat("(", ExprProtoToString(e.mul_node().lhs()), " * ", + ExprProtoToString(e.mul_node().rhs()), ")"); + case ExpressionProto::kDivNode: + return absl::StrCat("(", ExprProtoToString(e.div_node().lhs()), " / ", + ExprProtoToString(e.div_node().rhs()), ")"); + default: + return ""; + } +} + struct OutputInputTensorPairHasher { uint64 operator()(std::pair const& s) const { return 
Hash64Combine(OutputTensor::Hash()(s.first), @@ -369,6 +399,19 @@ class Encapsulator { namespace { +bool BuildOutputShapeProto(const Node& node, int output_slot, + TensorShapeProto* proto) { + AttrSlice attrs = node.attrs(); + auto shape_attr = + attrs.FindByString(kXlaInferredOutputTensorShapesAttrName); + if (shape_attr == nullptr || !shape_attr->has_list() || + shape_attr->list().shape_size() <= output_slot) { + return false; + } + *proto = shape_attr->list().shape(output_slot); + return true; +} + // Return in 'sorted' a topological sort of clusters according to the // dependencies encoded in ancestors. clusters is the list of all clusters // including clusters that are not present in the ancestors map. has_successors @@ -451,6 +494,31 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } +void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) { + auto e = expr->s(); + if (xla::Constant* c = dynamic_cast(e)) { + proto->set_constant_value(c->get_val()); + } else if (xla::Variable* v = dynamic_cast(e)) { + proto->set_variable_id(v->get_id()); + } else if (xla::Add* a = dynamic_cast(e)) { + auto* add_msg = proto->mutable_add_node(); + ExprToProto(a->get_lhs(), add_msg->mutable_lhs()); + ExprToProto(a->get_rhs(), add_msg->mutable_rhs()); + } else if (xla::Mul* m = dynamic_cast(e)) { + auto* mul_msg = proto->mutable_mul_node(); + ExprToProto(m->get_lhs(), mul_msg->mutable_lhs()); + ExprToProto(m->get_rhs(), mul_msg->mutable_rhs()); + } else if (xla::Sub* s = dynamic_cast(e)) { + auto* sub_msg = proto->mutable_sub_node(); + ExprToProto(s->get_lhs(), sub_msg->mutable_lhs()); + ExprToProto(s->get_rhs(), sub_msg->mutable_rhs()); + } else if (xla::Div* d = dynamic_cast(e)) { + auto* div_msg = proto->mutable_div_node(); + ExprToProto(d->get_lhs(), div_msg->mutable_lhs()); + ExprToProto(d->get_rhs(), div_msg->mutable_rhs()); + } +} + absl::Status 
Encapsulator::Subgraph::RecordArg( const Edge* edge, const absl::flat_hash_map& node_images, @@ -470,6 +538,22 @@ absl::Status Encapsulator::Subgraph::RecordArg( DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); + AttrSlice attrs = src_node->attrs(); + TensorShapeProto output_shape_proto; + if (BuildOutputShapeProto(*src_node, src_slot, &output_shape_proto)) { + VLOG(1) << "Adding following output shapes for node " << src_node->name() + << " : " << output_shape_proto.DebugString(); + builder.Attr("_output_shapes", {output_shape_proto}); + builder.Attr(kXlaInferredOutputShapesAttrName, {output_shape_proto}); + } else { + // if cluster argument is the real argument. + auto build_attr = attrs.FindByString("_dynamic_dim"); + if (build_attr) { + VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":" + << src_slot; + builder.Attr("_dynamic_dim", *build_attr); + } + } absl::Status s = builder.Finalize(&arg_def); if (!s.ok()) return s; @@ -1143,6 +1227,14 @@ static absl::Status RenumberArguments(Graph* graph, return absl::OkStatus(); } +static bool SubgraphHasFailingOps(const Graph& g) { + for (Node* n : g.op_nodes()) { + if (n->IsRetval()) continue; + if (kFailingOps.contains(n->def().op())) return true; + } + return false; +} + absl::Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; @@ -1289,8 +1381,8 @@ absl::Status EncapsulateSubgraphsPass::Run( // TODO(phawkins): add a forward is-constant analysis, similarly split // outputs into host-memory constants and device-memory non-constants. 
- - AddNodeAttr(kXlaCompiledKernelAttr, true, node); + bool compile_enabled = !SubgraphHasFailingOps(**subgraph); + AddNodeAttr(kXlaCompiledKernelAttr, compile_enabled, node); AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); return absl::OkStatus(); diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index fa94a341bbabc6..e0849dcf192c81 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -303,6 +303,10 @@ absl::Status PostprocessControlEdgesBetweenOutsideCompilations( } // namespace const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; +const char kXlaInferredOutputTensorShapesAttrName[] = + "_xla_inferred_output_tensor_shapes"; +const char kXlaInferredOutputShapesAttrName[] = + "_xla_inferred_output_shapes"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 7c99763c770728..25ca891529cadc 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -28,6 +28,15 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; +// Attribute marking Grappler-inferred output TensorShapeProtos. Attribute +// value is a list of TensorShapeProto objects and may include ExpressionProto +// annotations when available. +extern const char kXlaInferredOutputTensorShapesAttrName[]; + +// Attribute carrying inferred output TensorShapeProtos on encapsulated _Arg +// nodes for XlaCompileOp argument reconstruction. +extern const char kXlaInferredOutputShapesAttrName[]; + // Infers output shapes for all nodes in graph `g`. The output shapes will be // stored in node attribute `kXlaInferredShapesAttrName`. 
// diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 10756ddf9de7b5..212eca3e03156f 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -108,6 +108,15 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_max_cluster_size", &mark_for_compilation_flags->tf_xla_max_cluster_size, "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_annotate_cluster_id", + &mark_for_compilation_flags->tf_xla_annotate_cluster_id, + "Allow operator names to influence clustering scheme." + "Operators whose name starting with .cluster.{id} will likely" + "to be clustered together if the ids are the same number. " + ".cluster.none will not be clustered with those having numbered id"), + Flag("tf_xla_cluster_parallel", + &mark_for_compilation_flags->tf_xla_cluster_parallel, + "Split parallel compute subgraph info different clusters"), Flag( "tf_xla_ops_to_cluster", &mark_for_compilation_flags->tf_xla_ops_to_cluster, @@ -155,6 +164,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { &mark_for_compilation_flags->tf_xla_deterministic_cluster_names, "Causes the function names assigned by auto clustering to be " "deterministic from run to run."), + Flag("tf_xla_enable_dynamic_sizes", + &mark_for_compilation_flags->tf_xla_enable_dynamic_sizes, + "Enable dynamic sizes support."), Flag("tf_xla_persistent_cache_directory", &mark_for_compilation_flags->tf_xla_persistent_cache_directory, "If non-empty, JIT-compiled executables are saved to and loaded " @@ -175,6 +187,11 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { &mark_for_compilation_flags->tf_xla_persistent_cache_prefix, "Specifies the persistance cache prefix. Default is " "\"xla_compile_cache\""), + Flag("tf_xla_threshold_for_megamorphic", + &mark_for_compilation_flags->tf_xla_threshold_for_megamorphic, + "Sets the threshold for marking a cluster megamorphic. 
" + "Setting it to -1 disables marking clusters megamorphic." + "Setting it to 0 uses the default behaviour of TensorFlow."), Flag("tf_xla_sparse_core_disable_table_stacking", &sparse_core_flags->tf_xla_sparse_core_disable_table_stacking, "Disable table stacking for all the tables passed to the SparseCore" @@ -232,15 +249,19 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_min_cluster_size = 4; mark_for_compilation_flags->tf_xla_max_cluster_size = std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_annotate_cluster_id = false; + mark_for_compilation_flags->tf_xla_cluster_parallel = false; mark_for_compilation_flags->tf_xla_clustering_debug = false; mark_for_compilation_flags->tf_xla_cpu_global_jit = false; mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_threshold_for_megamorphic = 0; mark_for_compilation_flags ->tf_xla_disable_deadness_safety_checks_for_debugging = false; mark_for_compilation_flags ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false; mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false; + mark_for_compilation_flags->tf_xla_enable_dynamic_sizes = false; mark_for_compilation_flags->tf_xla_persistent_cache_directory = ""; mark_for_compilation_flags->tf_xla_persistent_cache_device_types = ""; mark_for_compilation_flags->tf_xla_persistent_cache_read_only = false; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 0d0c5082cf9a82..971dd8a7a38229 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -62,6 +62,12 @@ struct MarkForCompilationPassFlags { // Maximum number of operators in an XLA compilation. 
int32 tf_xla_max_cluster_size; + // Enable operator name to influence clustering decision + bool tf_xla_annotate_cluster_id; + + // Split parallel compute subgraph into different clusters + bool tf_xla_cluster_parallel; + // If non-empty, limit XLA clustering to the following TF operations. string tf_xla_ops_to_cluster; @@ -93,6 +99,9 @@ struct MarkForCompilationPassFlags { // so that they remain stable from run to run of auto clusteing. bool tf_xla_deterministic_cluster_names; + // If true enables support of dynamic sizes. + bool tf_xla_enable_dynamic_sizes; + // If non-empty, JIT-compiled executables are saved to and loaded from the // specified file system directory path. std::string tf_xla_persistent_cache_directory; @@ -111,6 +120,11 @@ struct MarkForCompilationPassFlags { // Specifies the persistance cache prefix. Default is "xla_compile_cache" string tf_xla_persistent_cache_prefix; + + // Sets the threshold for marking a cluster megamorphic. + // Setting it to -1 disables marking clusters megamorphic. + // Setting it to 0 uses the default behaviour of TensorFlow. + int64_t tf_xla_threshold_for_megamorphic; }; // Flags associated with XLA Sparse Core. 
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index dec057ebfba9ab..8da9744bacd79e 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/jit:tf_graph_to_hlo_compiler", "//tensorflow/compiler/jit:tf_to_hlo_compiler", "//tensorflow/compiler/jit:xla_compile_util", + "//tensorflow/compiler/jit:xla_batch_matcher", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:refcount", "@com_google_absl//absl/base:core_headers", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 468b85280e2a47..dd6a2d41053b3b 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/pjrt_compile_util.h" #include "tensorflow/compiler/jit/variable_info.h" @@ -54,6 +55,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_host_send_device_context.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/jit/xla_batch_matcher.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -65,6 +67,7 @@ limitations under the License. 
#include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/batch_size_resource.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -379,6 +382,72 @@ GetXlaCompilerArgsAndSnapshotVariables( return result; } + +std::unique_ptr ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DimExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DimExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = ExprFromProto(proto.add_node().lhs()); + auto rhs = ExprFromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = ExprFromProto(proto.sub_node().lhs()); + auto rhs = ExprFromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = ExprFromProto(proto.mul_node().lhs()); + auto rhs = ExprFromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = ExprFromProto(proto.div_node().lhs()); + auto rhs = ExprFromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { + switch (e->kind()) { + case DimExpr::Kind::kConstant: { + auto* ac = static_cast(e); + return xla::DynExpr::_(ac->value()); + } + case DimExpr::Kind::kVariable: { + auto* av = static_cast(e); + return xla::DynExpr::V(1); + } + 
case DimExpr::Kind::kAdd: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kSub: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kMul: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kDiv: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs()); + } + } + return nullptr; +} + + absl::Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, @@ -427,9 +496,200 @@ absl::Status CompileToLocalExecutable( XlaCompiler::CompileOptions compile_options = GenerateCompileOptions(has_ref_vars, may_alias_resource_update); - return xla_device_compiler->CompileIfNeeded( - options, function, args, compile_options, compile_mode, profiler, - compilation_result, executable); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + // Rewriting the argument with expressions if they have dynamic + // dimension, detecting dynamic dimension via either _dynamic_dim or the + // inferred-output-shapes attr attached during encapsulation. + std::vector norm_args(args.begin(), args.end()); + int64_t filled_batch = 0; + bool saw_dynamic_dim_value = false; + // Only supporting one dynamic dimension. 
+ bool has_multiple_dynamic_dim_values = false; + int64_t dynamic_dim_value = 0; + XlaBatchMatcher* xla_batch_matcher = + xla_device_compiler->xla_batch_matcher(); + auto record_dynamic_dim_value = [&](int64_t dim_size) { + if (!saw_dynamic_dim_value) { + saw_dynamic_dim_value = true; + dynamic_dim_value = dim_size; + return; + } + if (dynamic_dim_value != dim_size) { + has_multiple_dynamic_dim_values = true; + } + }; + if (options.flib_def != nullptr) { + const FunctionDef* fdef = options.flib_def->Find(function.name()); + if (fdef != nullptr) { + for (const auto& kv : fdef->arg_attr()) { + int arg_index = kv.first; + const auto& attr_map = kv.second.attr(); + const std::string& node_name = + fdef->signature().input_arg(arg_index).name(); + + // Special case for _dynamic_dim... + auto dyn_dim_attr = attr_map.find("_dynamic_dim"); + if (dyn_dim_attr != attr_map.end()) { + TensorShape& shp = + std::get(norm_args[arg_index].shape); + const AttrValue& v = dyn_dim_attr->second; + int64_t idx = v.i(); + record_dynamic_dim_value(shp.dim_size(idx)); + if (!filled_batch && xla_batch_matcher) { + filled_batch = + xla_batch_matcher->get_xla_compile_batch(shp.dim_size(idx)); + } + + std::vector dyn_exprs; + for (int d : shp.dim_sizes()) { + dyn_exprs.push_back(xla::DynExpr::_(d)); + } + dyn_exprs[idx] = xla::DynExpr::V(1); + shp.set_expressions(dyn_exprs); + continue; + } + auto it = attr_map.find(kXlaInferredOutputShapesAttrName); + if (it == attr_map.end()) continue; + + const TensorShapeProto& proto = it->second.list().shape(0); + const auto& exp = proto.expressions(); + TensorShape& shp = std::get(norm_args[arg_index].shape); + + if (!filled_batch && xla_batch_matcher) { + for (int idx = 0; idx < exp.size(); ++idx) { + // Look for dynamic expression. If found then compute padding + // value and exit loop. 
+ auto e = DimExprToDynExpr(ExprFromProto(exp[idx]).get())->s(); + if (e->is_dynamic()) { + int64_t var_value = e->solve(shp.dim_size(idx)); + if (var_value <= 0) { + LOG(WARNING) + << "Failed to solve dynamic dimension for argument " + << arg_index << " dim " << idx << " with size " + << shp.dim_size(idx) + << "; falling back to original dimension size."; + var_value = shp.dim_size(idx); + } else { + VLOG(1) << "Solved dynamic dimension from " + << shp.dim_size(idx) << " to " << var_value; + } + record_dynamic_dim_value(var_value); + filled_batch = + xla_batch_matcher->get_xla_compile_batch(var_value); + break; + } + } + } + + std::vector dyn_exprs; + for (int d : shp.dim_sizes()) { + dyn_exprs.push_back(xla::DynExpr::_(d)); + } + for (int j = 0; j < exp.size(); ++j) { + auto e = DimExprToDynExpr(ExprFromProto(exp[j]).get())->s(); + if (e->is_dynamic()) { + dyn_exprs[j] = e; + } + } + shp.set_expressions(dyn_exprs); + } + } + } + + struct SaveOldVar { + int arg_index; + int64_t dyn_dim; + int64_t old_value; + }; + std::vector old_vars; + auto maybe_rewrite_scalar_constant = [&](int arg_index) { + if (!saw_dynamic_dim_value || has_multiple_dynamic_dim_values) { + return; + } + + auto& arg = norm_args[arg_index]; + if (arg.kind != XlaCompiler::Argument::kConstant) { + return; + } + + const bool is_scalar = TensorShapeUtils::IsScalar(arg.constant_value.shape()); + const bool is_vector = TensorShapeUtils::IsVector(arg.constant_value.shape()) && + arg.constant_value.NumElements() > 0; + if (!is_scalar && !is_vector) { + return; + } + + if (arg.constant_value.dtype() == DT_INT32) { + const int32 old_value = arg.constant_value.flat()(0); + // Heuristic: rewrite only scalar constants or shape-like int vectors + // whose leading entry matches the observed runtime batch size. + if (old_value == dynamic_dim_value) { + // Deep-copy before rewrite so the compile-time patch does not mutate + // a Tensor buffer shared with caller-visible inputs. 
+ Tensor scalar_copy(arg.constant_value.dtype(), + arg.constant_value.shape()); + scalar_copy.flat() = arg.constant_value.flat(); + arg.constant_value = std::move(scalar_copy); + arg.constant_value.flat()(0) = + static_cast(filled_batch); + } + } else if (arg.constant_value.dtype() == DT_INT64) { + const int64_t old_value = arg.constant_value.flat()(0); + // Same heuristic for int64 scalar constants. + if (old_value == dynamic_dim_value) { + Tensor scalar_copy(arg.constant_value.dtype(), + arg.constant_value.shape()); + scalar_copy.flat() = arg.constant_value.flat(); + arg.constant_value = std::move(scalar_copy); + arg.constant_value.flat()(0) = filled_batch; + } + } + }; + // We rewrite only dynamic dimensions to the padded compile batch and then + // restore the original runtime sizes after compilation. Some scalar + // constants are actually runtime batch sizes folded by earlier TF passes, + // so rewrite only those that match the detected dynamic runtime value. + // Scalar constants are deep-copied before rewrite so the change stays + // local to norm_args and does not require restoration. + if (filled_batch) { + for (int i = 0; i < norm_args.size(); ++i) { + TensorShape& shp = std::get(norm_args[i].shape); + for (int j = 0; j < shp.get_expressions().size(); ++j) { + auto e = shp.get_expression(j); + if (e->is_dynamic()) { + int64_t old = shp.dim_size(j); + old_vars.push_back({i, j, old}); + xla::DynExpr* padded_expr = xla::DynExpr::_(filled_batch); + xla::DynExpr* subst_expr = e->substitute(1, padded_expr)->s(); + int64_t new_dim = subst_expr->get_val(); + if (new_dim >= 0) { + shp.set_dim(j, new_dim); + // Necessary because set_dim removes the expression: + shp.set_expression(j, e); + } + } + } + maybe_rewrite_scalar_constant(i); + } + } + auto status = xla_device_compiler->CompileIfNeeded( + options, function, norm_args, compile_options, compile_mode, profiler, + compilation_result, executable); + // Restore the original runtime dimensions after compilation. 
+ if (filled_batch) { + for (const auto& old_var : old_vars) { + TensorShape& shp = + std::get(norm_args[old_var.arg_index].shape); + shp.set_dim(old_var.dyn_dim, old_var.old_value); + } + } + return status; + } else { + return xla_device_compiler->CompileIfNeeded( + options, function, args, compile_options, compile_mode, profiler, + compilation_result, executable); + } } absl::Status GetUpdatedVariables( @@ -802,14 +1062,24 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx, function_, has_ref_vars_, platform_info_, args, compile_mode, /*may_alias_resource_update=*/false, &client, &kernel, &executable); } + if (compile_mode != DeviceCompileMode::kLazy || status.code() != error::UNIMPLEMENTED) { - OP_REQUIRES_OK(ctx, status); + if ((status != OkStatus()) && + (status.code() != error::UNIMPLEMENTED) && + (compile_mode == DeviceCompileMode::kLazy)) { + // We set the error to error::UNIMPLEMENTED so it falls in the + // conditions of the if to fall back to TensorFlow function call + status = tensorflow::errors::Unimplemented(status.ToString()); + } else { + OP_REQUIRES_OK(ctx, status); + } } if (status.code() == error::UNIMPLEMENTED) { - LOG(WARNING) << "Compilation failed:" << status - << ". Falling back to TF function call."; + LOG(WARNING) << "[HUAWEI] Compilation of the cluster failed with:"; + LOG(WARNING) << "[HUAWEI] " << status; + LOG(WARNING) << "[HUAWEI] Falling back to TF function call.\n"; BroadcastOptimizationRemark( XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString()) @@ -817,6 +1087,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { executable = nullptr; pjrt_executable = nullptr; mutex_lock guard(cannot_compile_cluster_mu_); + // TODO: decide if we want to set this flag to true, as we may want to + // allow the cluster to try to compile again later in time. 
cannot_compile_cluster_ = true; } } @@ -871,7 +1143,6 @@ XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); - bool use_pjrt = GetXlaOpsCommonFlags() ->tf_xla_use_device_api.IsEnabledInXlaCompileAndRunForDevice( @@ -953,6 +1224,69 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { xla::ExecutableRunOptions run_options; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes) { + bool is_set = false; + std::set dyn_vals; + const auto* comp_result = closure.compilation_result(); + const int num_constant_args = closure.num_constant_args(); + for (int i = 0; i < comp_result->xla_input_shapes.size(); i++) { + const auto& xla_shape = closure.compilation_result()->xla_input_shapes[i]; + if (!xla_shape.IsArray() || xla_shape.expressions().empty()) continue; + + for (int dim = 0; dim < xla_shape.expressions().size(); dim++) { + xla::DynExpr* expr = xla_shape.expressions(dim); + if (expr && expr->is_dynamic()) { + int input_idx = comp_result->input_mapping[i] - num_constant_args; + if (input_idx < 0 || input_idx >= ctx->num_inputs()) { + VLOG(1) << "Warning: Input index is out of range"; + continue; + } + VLOG(1) << "input shape is " << ctx->input(input_idx).shape() + << ", corresponding xla input shape is " << xla_shape; + int64_t size = ctx->input(input_idx).shape().dim_size(dim); + int64_t dyn_val = expr->solve(size); // TODO: check if the result is correct later. + VLOG(1) << "Found dynamic input. 
Real size is: " << size + << ", solved dynamic value is " << dyn_val; + if (dyn_val == -1) { + VLOG(1) << "Warning: Failed to solve the expression"; + continue; + } + dyn_vals.insert(dyn_val); + } + } + } + + if (dyn_vals.size() == 1) { + run_options.set_batch_size(*(dyn_vals.begin())); + is_set = true; + } else { + // Found multiple variables + VLOG(1) << "Warning: Found multiple variables"; + } + + if (!is_set) { + // TODO: Fallback to BatchSizeResource for now. Remove it later. + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + + absl::Status st = step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr); + + if (st.ok()) { + run_options.set_batch_size(bsr->GetBatchSize()); + VLOG(1) << "run_options.batch_size is set to: " + << run_options.batch_size() << ". step_id: " << ctx->step_id(); + bsr->Unref(); + + } else if (IsNotFound(st)) { + VLOG(1) << "Warning: Not found BatchSizeResource in step_container."; + } else { + OP_REQUIRES_OK(ctx, st); + } + } + } + // Host callbacks used for HLO send/recv. xla::SendDeviceMemoryFunction send_function = GetSendDeviceMemoryFunction(ctx, key); @@ -981,7 +1315,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { launch_context.PopulateOutputs( ctx, closure.compilation_result(), execution_output->ConsumeResult(), /*missing_ctx_input_prefix=*/closure.num_constant_args(), - absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs)); + absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs, + &run_options)); } XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index c3a24f3e0f7163..92c21546497efb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -23,12 +23,14 @@ limitations under the License. 
#include #include #include +#include #include #include #include #include #include #include +#include #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" @@ -39,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -68,6 +71,9 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" namespace tensorflow { @@ -109,6 +115,8 @@ class MarkForCompilationPassImpl { // stable from run to rum. 
bool deterministic_cluster_names; + bool enable_dynamic_sizes; + int max_cluster_size; int min_cluster_size; @@ -123,6 +131,11 @@ class MarkForCompilationPassImpl { std::atomic* fuel; bool dump_graphs; + + // Enable models to influcence clustering with operator names + int annotate_cluster_id; + + bool enable_cluster_parallel; }; MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, @@ -245,7 +258,17 @@ class MarkForCompilationPassImpl { " others #", cycles_graph_node_id(), ">"); } + int annotated_id() const { return annotated_id_; } + void set_annotated_id(int id) { annotated_id_ = id; } + int chain_id() const {return chain_id_;} + void set_chain_id(int id) {chain_id_ = id;} + void add_dim_var(int dim_var) { dim_vars_.insert(dim_var); } + const std::set& dim_vars() const { return dim_vars_; } + private: + int annotated_id_ = -1; + std::set dim_vars_; + int chain_id_ = -1; int cluster_size_ = 1; int cycles_graph_node_id_; int effective_cluster_size_; @@ -317,6 +340,17 @@ class MarkForCompilationPassImpl { return compilation_candidates_.find(n) != compilation_candidates_.end(); } + absl::Status AssignAnnotatedClusterIDs(); + absl::Status AssignDimVars(); + void collectInputNodes(std::set &path_nodes); + void collectMergeNodes(const std::vector& nodeSet, + std::set &merger_nodes); + void collectPathNodes(Node* start, std::set &path_nodes, + std::set& merger_nodes); + std::map> collectParallelNode( + const std::vector& nodeSet); + absl::Status AssignParallelChains(); + // Tries to contract the edge from cluster `from` to cluster `to`. Returns // true if successful. absl::StatusOr TryToContractEdge(Cluster* from, Cluster* to); @@ -651,6 +685,194 @@ absl::Status IgnoreResourceOpForSafetyAnalysis( } return absl::OkStatus(); } +// node mapping to multiple vectors of expressions (one for each output in +// order) +static std::map>>> + expr_map; +// Helper to convert ExpressionProto to a readable string. 
+std::string ExprProtoToString(const ExpressionProto& e) { + switch (e.node_type_case()) { + case ExpressionProto::kConstantValue: + return std::to_string(e.constant_value()); + case ExpressionProto::kVariableId: + return absl::StrCat("Var(", e.variable_id(), ")"); + case ExpressionProto::kAddNode: + return absl::StrCat("(", ExprProtoToString(e.add_node().lhs()), " + ", + ExprProtoToString(e.add_node().rhs()), ")"); + case ExpressionProto::kSubNode: + return absl::StrCat("(", ExprProtoToString(e.sub_node().lhs()), " - ", + ExprProtoToString(e.sub_node().rhs()), ")"); + case ExpressionProto::kMulNode: + return absl::StrCat("(", ExprProtoToString(e.mul_node().lhs()), " * ", + ExprProtoToString(e.mul_node().rhs()), ")"); + case ExpressionProto::kDivNode: + return absl::StrCat("(", ExprProtoToString(e.div_node().lhs()), " / ", + ExprProtoToString(e.div_node().rhs()), ")"); + default: + return ""; + } +} + +std::unique_ptr ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DimExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DimExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = ExprFromProto(proto.add_node().lhs()); + auto rhs = ExprFromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. 
+ return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = ExprFromProto(proto.sub_node().lhs()); + auto rhs = ExprFromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = ExprFromProto(proto.mul_node().lhs()); + auto rhs = ExprFromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = ExprFromProto(proto.div_node().lhs()); + auto rhs = ExprFromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +static xla::DynExpr* DimExprToDynExpr(const DimExpr* e) { + switch (e->kind()) { + case DimExpr::Kind::kConstant: { + auto* ac = static_cast(e); + return xla::DynExpr::_(ac->value()); + } + case DimExpr::Kind::kVariable: { + auto* av = static_cast(e); + return xla::DynExpr::V(av->id()); // Use 1 all the time for now + } + case DimExpr::Kind::kAdd: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) + *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kSub: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) - *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kMul: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) * *DimExprToDynExpr(ee->rhs()); + } + case DimExpr::Kind::kDiv: { + auto* ee = static_cast(e); + return *DimExprToDynExpr(ee->lhs()) / *DimExprToDynExpr(ee->rhs()); + } + } + return nullptr; +} + +// Runs Grappler static inference and logs any ExpressionProto found in output +// tensor shapes (from GraphProperties, not from _output_shapes attrs). 
+void LogExpressionsViaGraphProperties(tensorflow::Graph& graph) { + using tensorflow::ExpressionProto; + using tensorflow::GraphDef; + using tensorflow::NodeDef; + using tensorflow::TensorShapeProto; + using tensorflow::grappler::GraphProperties; + using tensorflow::grappler::GrapplerItem; + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + auto node_name_index = graph.BuildNodeNameIndex(); + + GrapplerItem item; + item.id = "mark_for_compilation_pass_expr_dump"; + item.graph = graph_def; + + GraphProperties props(item); + + absl::Status st = props.InferStatically( + /*assume_valid_feeds=*/false, + /*aggressive_shape_inference=*/false, + /*include_input_tensor_values=*/false, + /*include_output_tensor_values=*/false); + + if (!st.ok()) { + LOG(ERROR) << "[EXPR][GP] InferStatically failed: " << st.message(); + return; + } + + int found = 0; + VLOG(1) << "[EXPR][GP] === GraphProperties output expr dump ==="; + + auto convert_graph_properties_shape = [](const TensorShapeProto& gp_shape) { + TensorShapeProto out; + out.set_unknown_rank(gp_shape.unknown_rank()); + for (const auto& dim : gp_shape.dim()) { + out.add_dim()->set_size(dim.size()); + ExpressionProto* expr = out.add_expressions(); + if (dim.expr().node_type_case() != ExpressionProto::NODE_TYPE_NOT_SET) { + *expr = dim.expr(); + } else { + expr->set_constant_value(dim.size()); + } + } + return out; + }; + + for (const NodeDef& n : graph_def.node()) { + if (!props.HasOutputProperties(n.name())) continue; + const auto& outs = props.GetOutputProperties(n.name()); + std::vector inferred_output_shapes; + inferred_output_shapes.reserve(outs.size()); + std::vector>> list_exprs(outs.size()); + for (int out_idx = 0; out_idx < static_cast(outs.size()); ++out_idx) { + const auto& tp = outs[out_idx]; + const TensorShapeProto& shp = tp.shape(); + inferred_output_shapes.push_back(convert_graph_properties_shape(shp)); + + std::vector> exprs; + for (int d = 0; d < shp.dim_size(); ++d) { + const auto& dim = shp.dim(d); + 
+ const ExpressionProto& expr = dim.expr(); + if (expr.node_type_case() == ExpressionProto::NODE_TYPE_NOT_SET) + continue; + + VLOG(1) << "Node " << n.name() << " has expression " + << ExprProtoToString(expr); + + auto ex = ExprFromProto(expr); + exprs.push_back(std::move(ex)); + + ++found; + } + if (shp.dim_size() == 0 && shp.unknown_rank()) { + // Add two dummy variables to represent the unknown rank + exprs.push_back(std::make_unique(-888)); + exprs.push_back(std::make_unique(-889)); + } + + list_exprs[out_idx] = std::move(exprs); + } + expr_map[n.name()] = std::move(list_exprs); + auto node_it = node_name_index.find(n.name()); + if (node_it != node_name_index.end()) { + node_it->second->AddAttr(kXlaInferredOutputTensorShapesAttrName, + inferred_output_shapes); + } + + } + + VLOG(1) << "[EXPR][GP] === Found " << found + << " expressions via GraphProperties ==="; +} absl::StatusOr MarkForCompilationPassImpl::Initialize() { TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_); @@ -686,6 +908,19 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { // representative names the node in the 'cycles' graph that represents the // cluster. TF_RETURN_IF_ERROR(BuildInitialClusterSet()); + + // Source model may be annotated with preferred clusters. 
This function + // just interpreter the annotations and assign preferred IDs + if (debug_options_.annotate_cluster_id) { + TF_RETURN_IF_ERROR(AssignAnnotatedClusterIDs()); + } + if (debug_options_.enable_dynamic_sizes) { + LogExpressionsViaGraphProperties(*graph_); + TF_RETURN_IF_ERROR(AssignDimVars()); + } + if (debug_options_.enable_cluster_parallel) { + TF_RETURN_IF_ERROR(AssignParallelChains()); + } return true; } @@ -893,7 +1128,12 @@ absl::Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { return TryToContractEdge(from, to); })); + /* Clustering conditions for dynamic shapes may break the assumption of + * fixed point at phase 2, so this check may fail. + * Now just disable this check, and we can re-enable it after we have more + * confidence in the code. TF_RET_CHECK(!changed); + */ return absl::OkStatus(); } @@ -1028,6 +1268,7 @@ absl::Status MarkForCompilationPassImpl::CreateClusters() { // trouble. if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size || + cluster->chain_id() != -1 || cluster->has_functional_control_flow() || cluster->is_xla_compile_attr_true()) { string& name = cluster_names[cluster->cycles_graph_node_id()]; @@ -1548,14 +1789,303 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( return false; } +absl::Status MarkForCompilationPassImpl::AssignAnnotatedClusterIDs(void) { + VLOG(1) << "Run AssignAnnotatedClusterIDs"; + + for (Node* node : graph_->nodes()) { + Cluster * cluster = GetClusterForNode(node); + auto name = node->name(); + if (cluster) { + std::string pat = "^\\.cluster\\.(\\d+|none)"; + std::regex idPattern(pat); + std::smatch matched; + if (std::regex_search(name, matched, idPattern)) { + auto m = matched.str(1); + if (m == "none") { + // Prefer not to cluster + VLOG(1) << name << " : Default annotated cluster id -1"; + cluster->set_annotated_id(-1); + } + else { + try { + int id = std::stoi(m); + 
cluster->set_annotated_id(id); + VLOG(1) << name << " : Set annotated cluster id " << m; + } + catch (...) { + VLOG(1) << name << " : Invalid cluster id: " << m; + } + } + } + else { + VLOG(1) << "Not matched: " << name << " pattern is " << pat; + } + } + else { + VLOG(1) << name << ": Not initially clustered"; + } + } + return absl::OkStatus(); +} + +absl::Status MarkForCompilationPassImpl::AssignDimVars(void) { + for (Node* node : graph_->nodes()) { + auto node_name = node->name(); + Cluster * cluster = GetClusterForNode(node); + if (!cluster) continue; + for (const tensorflow::Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) { + // Skip control edges if you are only interested in data edges + continue; + } + + const tensorflow::Node* input = edge->src(); // Source node of the edge + auto it = expr_map.find(input->name()); + if (it == expr_map.end()) { + VLOG(2) << "No expression found for node " << input->name(); + continue; + } + + auto output_index = edge->src_output(); // Output index of the source node + if (output_index >= (it->second).size()) { + LOG(INFO) << "Warning: Output index " << output_index << " is out of bounds for node " << input->name(); + continue; + } + for (auto& pDim: (it->second)[output_index]) { + DimExpr * d= pDim.get(); + xla::DynExpr * dyn = DimExprToDynExpr(d); + auto new_ids = dyn->get_all_ids(); + for (auto id : new_ids) { + cluster->add_dim_var(id); + VLOG(2) << "Add dim var " << id << " to cluster of node "<< node_name; + } + } + } + // create a for loop for each dim vars in cluster and print each dim var + if (VLOG_IS_ON(2)) { + if (cluster->dim_vars().empty()) { + VLOG(2) << "Cluster of node " << node_name << " has no dim vars."; + } + else { + std::string id_str; + for (auto id : cluster->dim_vars()) { + id_str += "Dim var " + std::to_string(id) + ", "; + } + VLOG(2) << "Cluster of node " << node_name << " has dim vars:\n" << id_str; + } + } + } + return absl::OkStatus(); +} + bool 
MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( Cluster* from, Cluster* to, absl::string_view reason) { VLOG(3) << EdgeContractionFailureMsg(from, to, reason); return false; } +void MarkForCompilationPassImpl::collectInputNodes(std::set &path_nodes) { + std::unordered_map out_degree_count; + + // 4. Initialize the queue and add nodes from path_nodes + std::queue queue; + for (auto node : path_nodes) { + queue.push(node); + } + + // 5. BFS search + while (!queue.empty()) { + auto u = queue.front(); + queue.pop(); + + // Traverse all predecessor nodes + for (const Edge* e : u->in_edges()) { + Node* p = e->src(); + if (path_nodes.find(p) != path_nodes.end() || + !IsCompilationCandidate(p)) { + continue; + } + if (out_degree_count.count(p) == 0) { + // Initialize the out-degree count for the node + out_degree_count[p] = p->out_edges().size(); + } + // Decrease the out-degree count for the predecessor node + out_degree_count[p] -= 1; + + // If the predecessor node's out-degree count is 0 and not in path_nodes, + // add it to path_nodes and queue + if (out_degree_count[p] == 0) { + path_nodes.insert(p); + queue.push(p); + VLOG(3) << p->DebugString(); + } + } + } +} + +void MarkForCompilationPassImpl::collectMergeNodes( + const std::vector& nodeSet, std::set &merger_nodes) { + // 1. Collect the number of nodeSet that can reach each node + std::map> reach_map; + for (Node* start : nodeSet) { + std::set visited; + std::vector stack = {start}; + while (!stack.empty()) { + Node* cur = stack.back(); + stack.pop_back(); + if (visited.count(cur)) continue; + visited.insert(cur); + for (const Edge* e : cur->out_edges()) { + Node* next = e->dst(); + if (!visited.count(next)) + stack.push_back(next); + } + } + reach_map[start] = std::move(visited); + } + + // 2. 
Determine the merger node + std::map node_reach_count; + std::set all_nodes; + for (const auto& kv : reach_map) { + for (Node* n : kv.second) { + node_reach_count[n]++; + all_nodes.insert(n); + } + } + for (Node* n : all_nodes) { + // Condition 1: Reached by multiple sources + if (node_reach_count[n] >= 2) merger_nodes.insert(n); + // Condition 2: No output edges + if (n->out_edges().empty()) merger_nodes.insert(n); + } +} + +void MarkForCompilationPassImpl::collectPathNodes( + Node* start, std::set &path_nodes, std::set& merger_nodes) { + std::vector stack = {start}; + std::set visited; + + while (!stack.empty()) { + Node* cur = stack.back(); + stack.pop_back(); + if (visited.count(cur) || !IsCompilationCandidate(cur)) { + continue; + } + visited.insert(cur); + + // Stop search met the merger node + if (merger_nodes.count(cur)) { + if (cur->out_edges().empty()) path_nodes.insert(cur); + continue; + } + path_nodes.insert(cur); + + for (const Edge* e : cur->out_edges()) { + Node* next = e->dst(); + if (!visited.count(next)) { + stack.push_back(next); + } + } + } +} + +// collectParallelNode +// Search the serial merger nodes based on the parallel matmul starting points +// Search along the output edge to get the boundary from start to all merger points +// Search along the input edge to get the entire parallel computation graph +std::map> +MarkForCompilationPassImpl::collectParallelNode( + const std::vector& nodeSet) { + std::set merger_nodes; + collectMergeNodes(nodeSet, merger_nodes); + + // Collect path nodes + std::map> result; + for (Node* start : nodeSet) { + VLOG(4) << "Search parallel graph form node: " << start->DebugString(); + std::set path_nodes; + + // Search along output edge form start to merger nodes + collectPathNodes(start, path_nodes, merger_nodes); + + VLOG(4) << "Collect path nodes:"; + for (auto node : path_nodes) { + VLOG(4) << node->type_string(); + } + + VLOG(4) << "Collect input nodes:"; + // search along input edge + 
collectInputNodes(path_nodes); + + result[start] = std::vector(path_nodes.begin(), path_nodes.end()); + } + return result; +} + +// Collect parallel matmuls as input nodes into nodeSet +// Use collectParallelNode to get the parallel subgraph +// Mark parallel nodes change ID +absl::Status MarkForCompilationPassImpl::AssignParallelChains() { + VLOG(4) << "Run AssignParallelChains"; + // Record the matmuls that can be paralleled + std::vector> parallel_matmuls; + int minParallelMatmulNum = 2; + int next_chain_id = 0; + // Collect matmul nodes with shared input to parallel_matmuls + for (Node* node : graph_->nodes()) { + if (node->out_edges().size() < 2) continue; + std::vector matmul_nodes; + for (const Edge* e : node->out_edges()) { + if (e->IsControlEdge()) continue; + Node* succ = e->dst(); + VLOG(4) << "Find matmul node: " << succ->type_string() << " : " << succ->DebugString(); + if (succ->type_string() == "MatMul") + matmul_nodes.push_back(succ); + } + if (matmul_nodes.size() >= minParallelMatmulNum) + parallel_matmuls.push_back(matmul_nodes); + } + + for (auto matmul_nodes : parallel_matmuls) { + VLOG(4) << "Process matmul nodes: Total " << matmul_nodes.size() + << " sub matmuls"; + bool visited = false; + for (auto matmul : matmul_nodes) { + VLOG(4) << matmul->name(); + Cluster* cluster = GetClusterForNode(matmul); + if (!cluster || cluster->chain_id() != -1) { + visited = true; + break; + } + } + if (visited) { + VLOG(4) << "stop collect: has visited matmul"; + continue; + } + + std::map> subgraphMap = + collectParallelNode(matmul_nodes); + for (auto it : subgraphMap) { + auto nodeSet = it.second; + VLOG(4) << "One of Parallel sub-graph is: "; + for (auto node : nodeSet) { + VLOG(4) << node->DebugString(); + Cluster* cluster = GetClusterForNode(node); + cluster->set_chain_id(next_chain_id); + } + next_chain_id++; + } + } + return absl::OkStatus(); +} + absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( Cluster* from, Cluster* to) { + if 
(from->chain_id() != to->chain_id()) { + return LogNotContractableAndReturnFalse( + from, to, "nodes are in different parallel chains"); + } DCHECK(from->deadness_predicate().has_value() == to->deadness_predicate().has_value()); if (from->deadness_predicate() != to->deadness_predicate()) { @@ -1569,6 +2099,38 @@ absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( return false; } + if (debug_options_.annotate_cluster_id && from->annotated_id() != to->annotated_id()) { + return LogNotContractableAndReturnFalse( + from, to, "the two nodes do not have same annotated ids"); + } + + if (debug_options_.enable_dynamic_sizes) { + if (from->dim_vars().size() > 1 || to->dim_vars().size() > 1) { + std::string from_str = "from_vars: "; + for (auto id : from->dim_vars()) { + from_str += std::to_string(id) + ", "; + } + std::string to_str = "to_vars: "; + for (auto id : to->dim_vars()) { + to_str += std::to_string(id) + ", "; + } + return LogNotContractableAndReturnFalse( + from, to, absl::StrCat("the two nodes have multiple dynamic dimensions: ", + from_str, " and ", to_str)); + } + if (from->dim_vars().size() == 1 && to->dim_vars().size() == 1 && + from->dim_vars() != to->dim_vars()) { + return LogNotContractableAndReturnFalse( + from, to, + absl::StrCat("the two nodes have different dynamic dimensions: ", + from->dim_vars().size() == 1 + ? std::to_string(*from->dim_vars().begin()) : "none", + " and ", + to->dim_vars().size() == 1 + ? std::to_string(*to->dim_vars().begin()) : "none")); + } + } + TF_ASSIGN_OR_RETURN(bool devices_compatible, AreDevicesCompatible(*from, *to)); if (!devices_compatible) { @@ -1639,7 +2201,6 @@ absl::Status MarkForCompilationPassImpl::Run() { // MarkForCompilationPassImpl is not set up to run the subsequent phases. 
return absl::OkStatus(); } - TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); TF_RETURN_IF_ERROR(DeclusterNodes()); TF_RETURN_IF_ERROR(CreateClusters()); @@ -1958,10 +2519,14 @@ absl::Status MarkForCompilationPass::Run( debug_options.ignore_xla_compile_attr = false; debug_options.deterministic_cluster_names = flags->tf_xla_deterministic_cluster_names; + debug_options.enable_dynamic_sizes = + flags->tf_xla_enable_dynamic_sizes; debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); debug_options.dump_graphs = flags->tf_xla_clustering_debug; + debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; + debug_options.enable_cluster_parallel = flags->tf_xla_cluster_parallel; return MarkForCompilation(options, debug_options); } @@ -1977,10 +2542,12 @@ absl::Status MarkForCompilationPass::RunForTest( flags->tf_xla_disable_resource_variable_safety_checks_for_debugging; debug_options.ignore_xla_compile_attr = true; debug_options.deterministic_cluster_names = deterministic_cluster_names; + debug_options.enable_dynamic_sizes = false; debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); debug_options.dump_graphs = flags->tf_xla_clustering_debug; + debug_options.annotate_cluster_id = flags->tf_xla_annotate_cluster_id; return MarkForCompilation(options, debug_options); } diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 7a5106aa69bbbf..5f270d701a9c44 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -49,10 +49,17 @@ absl::Status ShapeHandleToTensorShape( if (!context->RankKnown(handle)) return absl::OkStatus(); std::vector dims(context->Rank(handle)); + std::vector 
dyn_exprs(context->Rank(handle)); for (int32_t i = 0, end = dims.size(); i < end; ++i) { dims[i] = context->Value(context->Dim(handle, i)); + auto ratio = context->DynamicRatio(context->Dim(handle, i)); + dyn_exprs[i] = ratio > 0 ? (ratio * *xla::DynExpr::V(1))->s() + : xla::DynExpr::_(dims[i]); // For now } - return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); + auto status = + PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); + shape->set_expressions(dyn_exprs); + return status; } absl::Status PropagateShapes( diff --git a/tensorflow/compiler/jit/xla_batch_matcher.cc b/tensorflow/compiler/jit/xla_batch_matcher.cc new file mode 100644 index 00000000000000..3cea1b8a3d11ba --- /dev/null +++ b/tensorflow/compiler/jit/xla_batch_matcher.cc @@ -0,0 +1,185 @@ +#include "tensorflow/compiler/jit/xla_batch_matcher.h" +#include "xla/debug_options_flags.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +XlaBatchMatcher::XlaBatchMatcher() { + const std::string xla_compile_batch_sizes = + xla::GetDebugOptionsFromFlags().xla_compile_batch_sizes(); + env_str_ = xla_compile_batch_sizes.c_str(); + parse_env_config(); +} + +// Trim whitespace (spaces/tabs) from both ends of a string +std::string trim(const std::string& s) { + size_t start = s.find_first_not_of(" \t"); + size_t end = s.find_last_not_of(" \t"); + return (start == std::string::npos) ? "" : s.substr(start, end - start + 1); +} + +std::vector XlaBatchMatcher::parse_single_item(const std::string& item) { + std::vector batch_list; + if (item.empty()) return batch_list; + + // 1. 
Parse single value (no colon separator) + if (item.find(':') == std::string::npos) { + // Validate all characters are digits (reject non-numeric chars) + for (char c : item) { + if (!isdigit(c)) { + throw std::invalid_argument("Non-numeric characters: " + item); + } + } + auto val = static_cast(std::stoll(item)); + if (val <= 0 || val > kMaxBatch) { + throw std::invalid_argument("Out of valid range (1-" + std::to_string(kMaxBatch) + "): " + item); + } + batch_list.push_back(val); + return batch_list; + } + + // 2. Parse range format (start:end:step) + std::stringstream ss(item); + std::string part; + std::vector parts; + while (std::getline(ss, part, ':')) { + parts.push_back(trim(part)); + } + if (parts.size() > 3) { + throw std::invalid_argument("Invalid range format (requires start:end:step): " + item); + } + + // Convert and validate range parameters + int64_t start, end, step; + try { + start = std::stoi(parts[0]); + end = std::stoi(parts[1]); + step = parts.size() == 2 ? 1 : std::stoi(parts[2]); + } catch (...) 
{ + throw std::invalid_argument("Invalid numeric values in range: " + item); + } + if (start <= 0 || end <= 0 || step <= 0) { + throw std::invalid_argument("Range parameters must be positive integers: " + item); + } + if (start > end) { + throw std::invalid_argument("Start value > end value in range: " + item); + } + if (end > kMaxBatch) { + throw std::invalid_argument("Range exceeds max limit (" + std::to_string(kMaxBatch) + "): " + item); + } + + // Generate batch list from range + for (int64_t i = start; i <= end; i += step) { + batch_list.push_back(i); + } + return batch_list; +} + +void XlaBatchMatcher::print_all_batches() { + std::ostringstream oss; + oss << "[XLA_BATCH_INFO] Valid batch list update: "; + for (size_t i = 0; i < all_batches_.size(); ++i) { + if (i > 0) oss << ", "; + oss << all_batches_[i]; + } + LOG(INFO) << oss.str(); +} + +// Parse environment variable config into deduplicated, sorted batch list +// For example, export XLA_COMPILE_BATCH_SIZES="10:100:10, 977" +void XlaBatchMatcher::parse_env_config() { + // If the env var not set or is empty, filled with the nearest power of two by default + if (!env_str_) { + VLOG(2) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << + " not set, filled with the nearest power of two by default"; + return; + } + + if (std::string(env_str_).empty()) { + VLOG(2) << "[XLA_BATCH_WARN] Env var " << "--tf_xla_compile_batch_sizes" << + "is empty, filled with the nearest power of two by default"; + return; + } + + // Split config by commas + std::stringstream ss(env_str_); + std::string item; + while (std::getline(ss, item, ',')) { + std::string trimmed_item = trim(item); + if (trimmed_item.empty()) continue; + + // Parse single item (skip on failure to avoid breaking other items) + try { + std::vector item_batches = parse_single_item(trimmed_item); + all_batches_.insert(all_batches_.end(), item_batches.begin(), item_batches.end()); + } catch (const std::exception& e) { + LOG(INFO) << 
"[XLA_BATCH_WARN] Failed to parse config item, skipping: " << + trimmed_item << " (" << e.what() << ")"; + } + } + + if (!all_batches_.empty()) { + std::sort(all_batches_.begin(), all_batches_.end()); + auto last = std::unique(all_batches_.begin(), all_batches_.end()); + all_batches_.erase(last, all_batches_.end()); + } + + // Print parsed result + if (!all_batches_.empty()) print_all_batches(); + return; +} + +// Calculate the smallest power of two greater than the real batch +static int64_t GetNextPowerOfTwo(int64_t real_batch) { + // If real_batch is already a power of two, return real_batch directly + if ((real_batch & (real_batch - 1)) == 0) { + return real_batch; + } + + int64_t power = 1; + while (power < real_batch) { + power <<= 1; + } + return power; +} + +int64_t XlaBatchMatcher::find_min_larger_batch(int64_t real_batch) { + if (real_batch <= 0 || real_batch > kMaxBatch) { + LOG(INFO) << "[XLA_BATCH_WARN] Out of valid range: " << real_batch; + return real_batch; + } + + if (all_batches_.empty()) { + // Return the next power of two directly without modifying all_batches_ + return GetNextPowerOfTwo(real_batch); + } + + // Edge case 1: Real value < the smallest batch, use smallest + if (real_batch < all_batches_.front()) { + return all_batches_.front(); + } + // Edge case 2: Real value ≥ the largest batch, use the nearest power of two + if (real_batch > all_batches_.back()) { + int64_t val = GetNextPowerOfTwo(real_batch); + all_batches_.emplace_back(val); + print_all_batches(); + return val; + } + + // Find first batch larger than real value (binary search via lower_bound) + auto it = std::lower_bound(all_batches_.begin(), all_batches_.end(), real_batch); + return (it != all_batches_.end()) ? 
*it : all_batches_.back(); +} + +int64_t XlaBatchMatcher::get_xla_compile_batch(int64_t real_batch) { + // Match target batch size + int64_t selected = find_min_larger_batch(real_batch); + if (real_batch != last_batch_ || all_batches_.empty()) { + last_batch_ = real_batch; + VLOG(2) << "[XLA_BATCH_INFO] Real batch: " << real_batch + << " -> Selected compile batch: " << selected; + } + return selected; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_batch_matcher.h b/tensorflow/compiler/jit/xla_batch_matcher.h new file mode 100644 index 00000000000000..0ef9a2dd13afcd --- /dev/null +++ b/tensorflow/compiler/jit/xla_batch_matcher.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_BATCH_MATCHER_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_BATCH_MATCHER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorflow { + +// Define the maximum allowed batch size: 2147483648 >> 1 = 1073741824 +// Exceeding this value will make the next power of two exceed the safe range +constexpr int kMaxBatch = 2147483648ULL >> 1; + +class XlaBatchMatcher { + public: + XlaBatchMatcher(); + virtual ~XlaBatchMatcher() = default; + int64_t get_xla_compile_batch(int64_t real_batch); + std::vector get_all_batches() { return all_batches_; } + + private: + void parse_env_config(); + void print_all_batches(); + std::vector parse_single_item(const std::string& item); + int64_t find_min_larger_batch(int64_t real_batch); + + std::vector all_batches_; + const char* env_str_; + int64_t last_batch_ = -1; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_BATCH_MATCHER_H_ \ No newline at end of file diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 588161df309131..06196e8d78f26c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -47,6 +47,7 @@ limitations under the 
License. #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/batch_size_resource.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -361,7 +362,8 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( ScopedShapedBuffer output, int missing_ctx_input_prefix, absl::Span variable_infos, const xla::HloInputOutputAliasConfig& input_output_alias, - const std::map& resource_vars) { + const std::map& resource_vars, + const xla::ExecutableRunOptions* run_options) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; Allocator* allocator = ctx->device()->GetAllocator({}); @@ -430,10 +432,52 @@ absl::Status XlaComputationLaunchContext::PopulateOutputs( } } else { for (int i = 0; i < ctx->num_outputs(); ++i) { - output_tensor_shapes.push_back(compilation_result->outputs[i].shape); + xla::Shape output_host_shape = output.on_host_shape(); + const xla::Shape& subshape = + xla::ShapeUtil::GetSubshape(output_host_shape, {i}); + VLOG(2) << "PopulateOutputs: subshape[" << i << "]: "<< subshape; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); + bool has_dynamic = false; + + for (int dim = 0; dim < subshape.expressions().size(); ++dim) { + auto expr = subshape.expressions(dim); + if (expr != nullptr && expr->is_dynamic()) { + has_dynamic = true; + VLOG(1) << "Current expression is " << expr; + if (run_options) { + xla::DynExpr* batch_size = xla::DynExpr::_(run_options->batch_size()); + xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + shape.set_dim(dim, subst_expr->get_val()); + } else { + // TODO: Fallback to BatchSizeResource for now. Remove it later. 
+ VLOG(1) << "Warning: Didn't find run_options"; + BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + TF_RETURN_IF_ERROR(step_container->Lookup( + ctx->resource_manager(), BatchSizeResourceName, &bsr)); + xla::DynExpr* batch_size = xla::DynExpr::_(bsr->GetBatchSize()); + // Just substitute Var(1) for now. + xla::DynExpr* subst_expr = expr->substitute(1, batch_size)->s(); + shape.set_dim(dim, subst_expr->get_val()); + bsr->Unref(); + } + } + } + if (has_dynamic) { + output_tensor_shapes.push_back(shape); + } + else { + output_tensor_shapes.push_back(compilation_result->outputs[i].shape); + } } } + VLOG(2) << "output_tensor_shapes:"; + for (auto s:output_tensor_shapes) { + VLOG(2) << s; + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0, end = ctx->num_outputs(); i < end; ++i) { diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 5e5128d515bf97..a2986c704f51ac 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "xla/client/local_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/service/shaped_buffer.h" +#include "xla/executable_run_options.h" #include "xla/stream_executor/device_memory_allocator.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -208,7 +209,8 @@ class XlaComputationLaunchContext { xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, absl::Span variable_infos, const xla::HloInputOutputAliasConfig& input_output_alias, - const std::map& resource_vars); + const std::map& resource_vars, + const xla::ExecutableRunOptions* run_options = nullptr); private: xla::LocalClient* client_; diff --git a/tensorflow/compiler/tf2xla/kernels/beta_op.cc b/tensorflow/compiler/tf2xla/kernels/beta_op.cc index 4ead9f76fcee11..9159df4f6a4ad5 100644 --- a/tensorflow/compiler/tf2xla/kernels/beta_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/beta_op.cc @@ -63,11 +63,14 @@ class BetaincOp : public XlaOpKernel { auto result = builder->ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN( - auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes())); + auto a, BroadcastTo(ctx->Input(0), merged_shape.dim_sizes(), + merged_shape.get_expressions())); TF_ASSIGN_OR_RETURN( - auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes())); + auto b, BroadcastTo(ctx->Input(1), merged_shape.dim_sizes(), + merged_shape.get_expressions())); TF_ASSIGN_OR_RETURN( - auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes())); + auto x, BroadcastTo(ctx->Input(2), merged_shape.dim_sizes(), + merged_shape.get_expressions())); return xla::RegularizedIncompleteBeta(a, b, x); }); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index fadd2f87219464..4f31c79f91a719 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -117,9 +117,13 @@ class 
DenseBincountOp : public XlaOpKernel { xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()}); auto i = xla::Iota(ctx->builder(), i_shape, 0); i = xla::Reshape( - i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}); + i, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, + {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(), + xla::DynExpr::one}); auto j = xla::Reshape( - input, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}); + input, {input_shape.dimensions(0) * input_shape.dimensions(1), 1}, + {(*input_shape.expressions(0) * *input_shape.expressions(1))->s(), + xla::DynExpr::one}); std::vector iotas_to_concat; iotas_to_concat.push_back(i); iotas_to_concat.push_back(j); @@ -130,7 +134,8 @@ class DenseBincountOp : public XlaOpKernel { zero, {output_shape.dimensions(0), output_shape.dimensions(1)}); if (has_weights && !binary_output_) { weights = xla::Reshape( - weights, {input_shape.dimensions(0) * input_shape.dimensions(1)}); + weights, {input_shape.dimensions(0) * input_shape.dimensions(1)}, + {(*input_shape.expressions(0) * *input_shape.expressions(1))->s()}); updates = weights; } } else { diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 975179466bf104..8176c4c820cb40 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -37,7 +37,8 @@ class BroadcastToOp : public XlaOpKernel { context->ConstantInputAsShape( 1, &output_shape, xla::ValueInferenceMode::kUpperBound)); auto output_status_or = - BroadcastTo(context->Input(0), output_shape.dim_sizes()); + BroadcastTo(context->Input(0), output_shape.dim_sizes(), + output_shape.get_expressions()); OP_REQUIRES_OK(context, output_status_or.status()); auto output = output_status_or.value(); std::vector dynamic_dims; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc 
b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 6b4f278c72beff..e3459681a98b13 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -46,11 +46,11 @@ class ClipByValueOp : public XlaOpKernel { if (shape != min_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); - min = xla::Broadcast(min, shape.dim_sizes()); + min = xla::Broadcast(min, shape.dim_sizes(), shape.get_expressions()); } if (shape != max_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); - max = xla::Broadcast(max, shape.dim_sizes()); + max = xla::Broadcast(max, shape.dim_sizes(), shape.get_expressions()); } ctx->SetOutput(0, xla::Clamp(min, input, max)); } diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index d2463a9974b1bb..043a263b9f2605 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -123,7 +123,8 @@ class ConstOp : public XlaOpKernel { if (shape.num_elements() > 1) { xla::XlaOp value = GetScalarConst(proto_, b); if (value.valid()) { - ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes(), + shape.get_expressions())); return; } } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 3fe22dcb4441e7..466707f0d777e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -97,8 +97,11 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput( CHECK_GE(num_dims, 2); // Crash OK xla::Shape new_shape = filter_shape; new_shape.set_dimensions(num_dims - 1, num_groups); - new_shape.add_dimensions(filter_shape.dimensions(num_dims - 1) / num_groups); - xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions()); + 
new_shape.add_dimensions( + filter_shape.dimensions(num_dims - 1) / num_groups, + (*filter_shape.expressions(num_dims - 1) / num_groups)->s()); + xla::XlaOp result = + xla::Reshape(filter, new_shape.dimensions(), new_shape.expressions()); // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G] std::vector transpose_dims(num_dims + 1); @@ -118,7 +121,8 @@ xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, xla::XlaOp filter) { return xla::Reshape( filter, - GroupedFilterShapeForDepthwiseConvolution(filter_shape).dimensions()); + GroupedFilterShapeForDepthwiseConvolution(filter_shape).dimensions(), + GroupedFilterShapeForDepthwiseConvolution(filter_shape).expressions()); } // Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA @@ -603,7 +607,8 @@ absl::StatusOr MakeXlaBackpropFilterConvOp( } if (attrs.depthwise) { - filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions()); + filter_backprop = xla::Reshape(filter_backprop, filter_shape.dimensions(), + filter_shape.expressions()); } return filter_backprop; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index b1da0acd61608f..f90bad207cfbd4 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -110,7 +110,8 @@ class ConvNDOp : public XlaOpKernel { expanded_input_shape.set_dimensions(i + 1, input_shape.dimensions(i)); } expanded_input_shape.set_dimensions(0, 1); - input = xla::Reshape(input, expanded_input_shape.dimensions()); + input = xla::Reshape(input, expanded_input_shape.dimensions(), + expanded_input_shape.expressions()); } else if (attrs_.batch_dims > 1) { // Flatten batch_dims. 
std::vector to_collapse(attrs_.batch_dims); @@ -131,7 +132,8 @@ class ConvNDOp : public XlaOpKernel { if (attrs_.batch_dims == 0) { xla::Shape no_batch_shape(out_shape); no_batch_shape.DeleteDimension(0); - out = xla::Reshape(out, no_batch_shape.dimensions()); + out = xla::Reshape(out, no_batch_shape.dimensions(), + no_batch_shape.expressions()); } else if (attrs_.batch_dims > 1) { xla::Shape expanded_out_shape(input_shape); for (int i = attrs_.batch_dims; i < input_shape.dimensions().size(); @@ -139,7 +141,8 @@ class ConvNDOp : public XlaOpKernel { expanded_out_shape.set_dimensions( i, out_shape.dimensions(i - (attrs_.batch_dims - 1))); } - out = xla::Reshape(out, expanded_out_shape.dimensions()); + out = xla::Reshape(out, expanded_out_shape.dimensions(), + expanded_out_shape.expressions()); } ctx->SetOutput(0, out); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 404fa9f5e04e45..d8740aa137fbeb 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -103,7 +103,7 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64_t size = input_shape.num_elements(); - input = xla::Reshape(input, {size}); + input = xla::Reshape(input, {size}, {}); // Create an R2 with the R1 diagonal. 
xla::XlaOp diag = CreateDiagonal(input, size, /*other_dims=*/{}); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index ceeea010ee7858..84c091339c1804 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -140,8 +140,9 @@ class DynamicPartitionOp : public XlaOpKernel { for (int64_t i = 0; i < rank; ++i) { broadcasted_dims.push_back(i); } - partitions = xla::BroadcastInDim(partitions, data_shape.dimensions(), - broadcasted_dims); + partitions = + xla::BroadcastInDim(partitions, data_shape.dimensions(), + broadcasted_dims, data_shape.expressions()); } // Output shape bounded is calculated by diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index cb7e4f6f96437e..6674fc6cde0793 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -174,9 +174,12 @@ class DynamicStitchOp : public XlaOpKernel { TensorShape new_shape; // first reshaped dimension is the number of indices for this input. new_shape.AddDim(indices[input_num].shape().dimensions(0)); + new_shape.AddExpression( + xla::DynExpr::_(indices[input_num].shape().dimensions(0))); // Then the rest are the common extra shape. for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { new_shape.AddDim(data0_shape.dim_size(d)); + new_shape.AddExpression(data0_shape.get_expression(d)); } // Get the data, shaped appropriately. 
auto handle = data[input_num]; diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 2a65441eb79bf9..d19e9f76f2c9b0 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -127,7 +127,8 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { xla::XlaOp between_nudged_min_max = xla::And( xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type), - gradient_shape.dim_sizes()); + gradient_shape.dim_sizes(), + gradient_shape.get_expressions()); xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -213,7 +214,8 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { xla::XlaOp between_nudged_min_max = xla::And( xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes(), + gradient_shape.get_expressions()); xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); @@ -268,9 +270,12 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public XlaOpKernel { xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); + absl::Span input_expressions = + input_shape.expressions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, - {input_shape.dimensions_size() - 1}); + {input_shape.dimensions_size() - 1}, + input_expressions); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); @@ -323,6 +328,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { xla::XlaBuilder* b = 
ctx->builder(); xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); + absl::Span input_expressions = + input_shape.expressions(); std::vector reduce_axes; for (int64_t i = 0; i + 1 < input_shape.dimensions_size(); ++i) { @@ -331,7 +338,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, - {input_shape.dimensions_size() - 1}); + {input_shape.dimensions_size() - 1}, + input_expressions); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); @@ -343,7 +351,8 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { xla::XlaOp between_nudged_min_max = xla::And( xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes(), + gradient_shape.get_expressions()); xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2783951e1b6b0f..e87155821c523f 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -88,7 +88,8 @@ absl::Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, out_shape.AppendShape(input_shape_post_axis); *gather_output = - xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); + xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes(), + out_shape.get_expressions()); return absl::OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index a8eb7bbf794268..162ce34488a620 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -59,7 +59,7 @@ std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, auto minimum = xla::Min(xla::Min(red, green), blue); auto range = xla::Sub(value, minimum); - auto zeros = xla::Broadcast(zero, shape.dim_sizes()); + auto zeros = xla::Broadcast(zero, shape.dim_sizes(), shape.get_expressions()); auto saturation = xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index f357262a39c35b..e6e90965dbffcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -87,9 +87,11 @@ class InTopKOp : public XlaOpKernel { // which indicates the target is in topk. xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0}); xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32); - xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes()); + xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes(), + predictions_shape.get_expressions()); xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32); - xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes()); + xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes(), + predictions_shape.get_expressions()); xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2); xla::XlaOp num_gt_r1 = xla::Reduce( one_hot_r2, zero_r0, diff --git a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc index 46e46f6d8b3d32..3ff4c4b0c6a221 100644 --- a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc @@ -49,14 +49,16 @@ void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype, // dimension of sorted_sequence. 
auto new_values_shape = values_shape; new_values_shape.InsertDim(/* d */ 2, /* size */ 1); - auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes()); + auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes(), + new_values_shape.get_expressions()); // Add a new penultimate dimension to sorted_inputs, to allow broadcasting of // sorted_sequence entries for each value. auto new_sorted_inputs_shape = sorted_inputs_shape; new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1); auto sorted_inputs_reshaped = - xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes()); + xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes(), + new_sorted_inputs_shape.get_expressions()); // We are relying on broadcasting to compare each value against each entry in // the associated sorted_inputs row. diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 48e8f976cc67bb..54a61a33d448a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -234,7 +234,8 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, // Broadcast and mask. xla::XlaOp diag_broadcast = xla::BroadcastInDim( - diag_slice, input_shape.dim_sizes(), broadcast_dimensions); + diag_slice, input_shape.dim_sizes(), broadcast_dimensions, + input_shape.get_expressions()); const auto mask = xla::GetDiagonalMask(output, diag_index); output = xla::Select(mask, diag_broadcast, output); } @@ -327,8 +328,11 @@ class MatrixDiagOp : public XlaOpKernel { TensorShape output_shape = diag_shape; output_shape.RemoveLastDims((num_diags == 1) ? 
1 : 2); output_shape.AddDim(num_rows); + output_shape.AddExpression(xla::DynExpr::_(num_rows)); output_shape.AddDim(num_cols); - xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes()); + output_shape.AddExpression(xla::DynExpr::_(num_cols)); + xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes(), + output_shape.get_expressions()); xla::XlaOp diag = context->Input(0); context->SetOutput( 0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags, @@ -404,11 +408,15 @@ class MatrixDiagPartOp : public XlaOpKernel { TensorShape output_shape = input_shape; output_shape.RemoveLastDims(2); const int num_diags = upper_diag_index - lower_diag_index + 1; - if (num_diags > 1) output_shape.AddDim(num_diags); + if (num_diags > 1) { + output_shape.AddDim(num_diags); + output_shape.AddExpression(xla::DynExpr::_(num_diags)); + } const int32_t max_diag_len = std::min(num_rows + std::min(upper_diag_index, int64_t{0}), num_cols - std::max(lower_diag_index, int64_t{0})); output_shape.AddDim(max_diag_len); + output_shape.AddExpression(xla::DynExpr::_(max_diag_len)); // Computes output. 
xla::XlaOp input = context->Input(0); @@ -447,7 +455,8 @@ class MatrixDiagPartOp : public XlaOpKernel { } auto concat = xla::ConcatInDim(context->builder(), diag_list, input_rank - 2); - context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes())); + context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes(), + output_shape.get_expressions())); } private: @@ -519,11 +528,15 @@ class MatrixSetDiagOp : public XlaOpKernel { TensorShape expected_diag_shape = input_shape; expected_diag_shape.RemoveLastDims(2); - if (num_diags > 1) expected_diag_shape.AddDim(num_diags); + if (num_diags > 1) { + expected_diag_shape.AddDim(num_diags); + expected_diag_shape.AddExpression(xla::DynExpr::_(num_diags)); + } const int32_t max_diag_len = std::min(num_rows + std::min(upper_diag_index, int64_t{0}), num_cols - std::max(lower_diag_index, int64_t{0})); expected_diag_shape.AddDim(max_diag_len); + expected_diag_shape.AddExpression(xla::DynExpr::_(max_diag_len)); OP_REQUIRES( context, expected_diag_shape == diag_shape, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 17b5ae7a70375a..4d1139ed76b460 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -96,8 +96,11 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape()); lhs_broadcast_shape.AddDim(m); + lhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); lhs_broadcast_shape.AddDim(m); - auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes()); + lhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); + auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes(), + lhs_broadcast_shape.get_expressions()); if (!lhs_output.ok()) { xla::XlaOp error = 
lhs.builder()->ReportError(lhs_output.status()); return {error, error}; @@ -105,8 +108,11 @@ MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape()); rhs_broadcast_shape.AddDim(m); + rhs_broadcast_shape.AddExpression(xla::DynExpr::_(m)); rhs_broadcast_shape.AddDim(n); - auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes()); + rhs_broadcast_shape.AddExpression(xla::DynExpr::_(n)); + auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes(), + rhs_broadcast_shape.get_expressions()); if (!rhs_output.ok()) { xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); return {error, error}; diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index ba4e8bbef7b136..3d84437dba7869 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -66,9 +66,17 @@ class PackOp : public XlaOpKernel { TensorShape child_shape(shapes[0]); child_shape.InsertDim(axis, 1); + // Equivalent to InsertDim(axis, 1) for expressions + std::vector exprs; + for (auto e : child_shape.get_expressions()) { + exprs.push_back(e); + } + exprs.insert(exprs.begin() + axis, xla::DynExpr::one); + for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. 
- reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes()); + reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes(), + exprs); } ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index aa7c78b8b8f97a..b3cdb7bac0c7dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -264,10 +264,14 @@ class MaxPoolOp : public PoolingOp { absl::InlinedVector new_dims(result_shape->dimensions().begin(), result_shape->dimensions().end()); + absl::InlinedVector new_exprs( + result_shape->expressions().begin(), + result_shape->expressions().end()); new_dims[1] /= *vect_width; + new_exprs[1] = *new_exprs[1] / *vect_width; new_dims.insert(new_dims.begin() + 2, *vect_width); - pooling = - xla::Transpose(xla::Reshape(pooling, new_dims), {0, 1, 3, 4, 2}); + pooling = xla::Transpose(xla::Reshape(pooling, new_dims, new_exprs), + {0, 1, 3, 4, 2}); } ctx->SetOutput(0, pooling); diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index cac9f8a68f234e..ba6860caa9cb1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -173,8 +173,11 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { if (!xla::ShapeUtil::IsScalar(axis_shape)) { xla::Shape input_shape = b->GetShape(input).value(); absl::Span input_dimensions = input_shape.dimensions(); + absl::Span input_expressions = + input_shape.expressions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { - return xla::BroadcastInDim(op, input_dimensions, {axis_}); + return xla::BroadcastInDim(op, input_dimensions, {axis_}, + input_expressions); }; min_range = convert_to_input_shape(min_range); max_range = 
convert_to_input_shape(max_range); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 5f911018c244b5..be6b6e077656b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" @@ -147,10 +148,21 @@ class MeanOp : public XlaReductionOp { xla::XlaOp result = reduce_output; xla::Shape bounded_shape = builder->GetShape(input).value(); int64_t divisor_value = bounded_shape.dimensions(dimensions_to_reduce[0]); - auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + xla::XlaOp divisor; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_enable_dynamic_sizes && dimensions_to_reduce[0] == 0) { + divisor = xla::GetOuterBatchValue(input); + } else { + divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + } for (int i = 1; i < dimensions_to_reduce.size(); i++) { int64_t size_value = bounded_shape.dimensions(dimensions_to_reduce[i]); - auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + xla::XlaOp size; + if (flags->tf_xla_enable_dynamic_sizes && dimensions_to_reduce[i] == 0) { + size = xla::GetOuterBatchValue(input); + } else { + size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + } if (size_value * divisor_value > std::numeric_limits::max()) { result = result / xla::ConvertElementType(divisor, xla_reduction_type_); divisor_value = size_value; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 6a8a98342c1123..9ed1116c08dc7a 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -106,16 +106,19 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { } std::vector final_shape; + std::vector final_exprs; for (int i = 0; i < data_shape.dims(); ++i) { if (!bitmap[i]) { // If we are not reducing along dimension i. int64_t dim = data_shape.dim_size(i); final_shape.push_back(dim); + final_exprs.push_back(data_shape.get_expression(i)); } else if (keep_dims_) { // We are reducing along dimension i, but we want to keep the // same number of dimensions, so we set the dimension of i to // '1'. final_shape.push_back(1); + final_exprs.push_back(xla::DynExpr::one); } } @@ -139,7 +142,8 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto finalized = BuildFinalizer(b, data, reduce, xla_axes); - auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; + auto result = keep_dims_ ? xla::Reshape(finalized, final_shape, final_exprs) + : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index f274b271596ff5..9a7c875347f120 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -40,8 +40,28 @@ XlaOp Relu6(XlaOp x) { namespace tensorflow { namespace { -REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel); -REGISTER_XLA_OP(Name("Relu6"), MlirXlaOpKernel); +class ReluOp : public XlaOpKernel { + public: + explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. 
+ void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, xla::Relu(ctx->Input(0))); + } +}; + +class Relu6Op : public XlaOpKernel { + public: + explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, xla::Relu6(ctx->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("Relu"), ReluOp); +// REGISTER_XLA_OP(Name("Relu"), MlirXlaOpKernel); +// REGISTER_XLA_OP(Name("Relu6"), MlirXlaOpKernel); +REGISTER_XLA_OP(Name("Relu6"), Relu6Op); class LeakyReluOp : public XlaOpKernel { public: @@ -70,9 +90,11 @@ class Relu6GradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto six = xla::Broadcast( - XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes(), + shape.get_expressions()); + const auto six = + xla::Broadcast(XlaHelpers::IntegerLiteral(b, input_type(0), 6), + shape.dim_sizes(), shape.get_expressions()); auto out = xla::Select( xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)), ctx->Input(0), zero); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index ba17d1b295b763..7f444777cd4cdb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -57,8 +57,10 @@ class ReshapeOp : public XlaOpKernel { // is one. 
TensorShape shape; int64_t product = 1; + xla::DynExpr* product_expr = xla::DynExpr::one; int unknown_index = -1; bool shape_has_zero_dim = false; + int ratio = 1; for (int d = 0; d < num_dims; ++d) { const int64_t size = shape_input[d]; if (size == -1) { @@ -68,23 +70,62 @@ class ReshapeOp : public XlaOpKernel { unknown_index, " and ", d)); unknown_index = d; shape.AddDim(1); + shape.AddExpression(xla::DynExpr::one); + ratio = 1; } else if (size == 0) { // We don't include zero-sized dimension in product, so that we can // still calculate number of elements for non-zero-sized dimensions and // therefore infer their shapes. shape.AddDim(size); + shape.AddExpression(xla::DynExpr::_(size)); shape_has_zero_dim = true; } else { + xla::DynExpr* size_expr; OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( "size ", d, " must be non-negative, not ", size)); shape.AddDim(size); + xla::DynExpr* input_expr = + d < input_shape.dims() ? input_shape.get_expression(d) : nullptr; + if (input_expr != nullptr && input_expr->is_dynamic()) { + int old = input_shape.dim_size(d); + bool is_split = (old > size); + int local_ratio = ratio * (is_split ? old / size : size / old); + xla::DynExpr* new_expr = + (size > old) + ? *input_expr * + *xla::DynExpr::_(local_ratio) // Split [xy] -> [x/y,y] + : *input_expr / + *xla::DynExpr::_(local_ratio); // Reduce [x,y] -> [x*y] + + // Pass ratio to next dimension if this is a split, otherwise just + // reset it to 1. + ratio = is_split ? local_ratio : 1; + size_expr = new_expr->s(); + + } else { + size_expr = xla::DynExpr::_(size); + if (ratio != 1) { + // A split dynamic dimension can be materialized by multiple later + // known dimensions. Any unresolved remainder is kept in `ratio` + // and may be consumed by a subsequent `-1` dimension (if present); + // otherwise, it remains unapplied. 
+ if (ratio % size == 0) { + ratio /= size; + } else if (size % ratio == 0) { + ratio = 1; + } + } + } + shape.AddExpression(size_expr); product *= size; + product_expr = (*product_expr * *size_expr); } } auto input = ctx->Input(0); if (unknown_index != -1) { int64_t input_num_elements = 1; + xla::DynExpr* input_num_elements_expr = xla::DynExpr::one; bool input_has_zero_dim = false; for (int dim = 0; dim < input_shape.dims(); dim++) { // For zero dimension, we don't count it into `input_num_elements` @@ -92,12 +133,17 @@ class ReshapeOp : public XlaOpKernel { // infer shapes for other dimensions. if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) { input_num_elements *= input_shape.dim_size(dim); + input_num_elements_expr = + (*input_num_elements_expr * *input_shape.get_expression(dim))->s(); } else { input_has_zero_dim = true; } } int64_t missing = input_num_elements / product; + input_num_elements_expr = input_num_elements_expr->s(); + product_expr = product_expr->s(); + auto missing_expr = *input_num_elements_expr / *product_expr; if (!input_has_zero_dim) { if (input_xla_shape->is_static() || input_xla_shape->dimensions().size() != 1) { @@ -119,10 +165,19 @@ class ReshapeOp : public XlaOpKernel { input, xla::Zero(ctx->builder(), input_xla_shape->element_type()), 0, 0, padded_input_num - input_num_elements); input_shape.set_dim(0, padded_input_num); + // This expression only approximates the padded size: the true value + // uses ceil(input_num_elements / product) * product, which we do not + // model symbolically here. 
+ xla::DynExpr* padded_input_num_expr = + (*(*input_num_elements_expr / *product_expr) * *product_expr)->s(); + input_shape.set_expression(0, padded_input_num_expr); } } shape.set_dim(unknown_index, missing); + shape.set_expression( + unknown_index, missing_expr->s()); } + OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), errors::InvalidArgument("Input to reshape is a tensor with ", input_shape.num_elements(), @@ -131,19 +186,23 @@ class ReshapeOp : public XlaOpKernel { VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; + if (input_xla_shape->is_static()) { - ctx->SetOutput(0, xla::Reshape(input, shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Reshape(input, shape.dim_sizes(), shape.get_expressions())); return; } std::vector output_dim_sizes; std::vector dims_are_dynamic; + std::vector output_dim_exprs; const auto& dims = shape.dims(); dims_are_dynamic.reserve(dims); output_dim_sizes.reserve(dims); for (int64_t i = 0; i < dims; ++i) { output_dim_sizes.push_back( xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); + output_dim_exprs.push_back(xla::DynExpr::_(-111)); } OP_REQUIRES_OK( ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); @@ -151,7 +210,7 @@ class ReshapeOp : public XlaOpKernel { // No unknown index. ctx->SetOutput( 0, xla::DynamicReshape(input, output_dim_sizes, shape.dim_sizes(), - dims_are_dynamic)); + dims_are_dynamic, output_dim_exprs)); return; } auto common_factors = @@ -166,21 +225,26 @@ class ReshapeOp : public XlaOpKernel { // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group // containing -1 will be 6. 
xla::XlaOp product = xla::One(ctx->builder(), xla::S32); + xla::DynExpr* product_expr = xla::DynExpr::one; for (int64_t dim = start.first; dim < end.first; ++dim) { if (input_xla_shape->is_dynamic_dimension(dim)) { input_is_dynamic = true; } product = xla::Mul(product, xla::GetDimensionSize(input, dim)); + product_expr = (*product_expr * *input_shape.get_expression(dim))->s(); } bool unknown_dim_in_group = false; // The real size for the -1 dimension in a reshape. E.g., in // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2. xla::XlaOp unknown_dim_size = product; + xla::DynExpr* unknown_dim_expr = product_expr; for (int64_t dim = start.second; dim < end.second; ++dim) { if (dim == unknown_index) { unknown_dim_in_group = true; } else { unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]); + unknown_dim_expr = + (*unknown_dim_expr / *output_dim_exprs[dim])->s(); } } @@ -188,12 +252,13 @@ class ReshapeOp : public XlaOpKernel { // If input dim is dynamic, output dim at the -1 position must be // dynamic. Similarly, if input dim is static, output dim has to be // static at the -1 dimension. 
+ output_dim_exprs[unknown_index] = unknown_dim_expr; dims_are_dynamic[unknown_index] = input_is_dynamic; output_dim_sizes[unknown_index] = unknown_dim_size; ctx->SetOutput( 0, xla::DynamicReshape(input, output_dim_sizes, shape.dim_sizes(), - dims_are_dynamic)); + dims_are_dynamic, output_dim_exprs)); VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to " << xla::VectorString(shape.dim_sizes()) << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5cecbf37706283..0515dd3b785cfb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -86,11 +86,17 @@ class ReverseSequenceOp : public XlaOpKernel { xla::XlaOp back = xla::Sub(seq_lens, xla::ScalarLike(seq_lens, 1)); xla::XlaOp batch_idx = xla::Iota( builder, - xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), + xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}, + {input_shape.get_expression(batch_dim_), + input_shape.get_expression(seq_dim_), + xla::DynExpr::one}), /*iota_dimension=*/0); xla::XlaOp forward_idx = xla::Iota( builder, - xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), + xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}, + {input_shape.get_expression(batch_dim_), + input_shape.get_expression(seq_dim_), + xla::DynExpr::one}), /*iota_dimension=*/1); xla::XlaOp reverse_idx = xla::Sub(back, forward_idx, {0}); reverse_idx = xla::Select(xla::Lt(reverse_idx, xla::ZerosLike(reverse_idx)), diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 694b4eb17ef298..95f52049872710 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -118,7 +118,8 @@ class 
ScatterNdOp : public XlaOpKernel { xla::XlaBuilder* builder = context->builder(); auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype), - buffer_shape.dim_sizes()); + buffer_shape.dim_sizes(), + buffer_shape.get_expressions()); auto indices = context->Input(0); auto updates = context->Input(1); auto combine = diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 21eaac25f058ed..faf6822b9a074d 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -96,7 +96,8 @@ class SegmentReduce : public XlaOpKernel { buffer_shape.InsertDim(0, num_segments); auto buffer = - xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); + xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes(), + buffer_shape.get_expressions()); // Build dynamic dim sizes for buffer, as well as whether each dimension // size is dynamic or static. We build two parts: num_sgement part and diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 85aaabe87076c2..3241dfc61609a1 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -70,8 +70,10 @@ class SelectOp : public XlaOpKernel { // Broadcast into the dimensions on the right. 
std::vector broadcast_dimensions(cond_shape.dims()); absl::c_iota(broadcast_dimensions, 0); + cond_handle = xla::BroadcastInDim(cond_handle, then_shape.dim_sizes(), - broadcast_dimensions); + broadcast_dimensions, + then_shape.get_expressions()); } ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } @@ -81,7 +83,8 @@ class SelectOp : public XlaOpKernel { void operator=(const SelectOp&) = delete; }; -REGISTER_XLA_OP(Name("Select"), MlirXlaOpKernel); +// REGISTER_XLA_OP(Name("Select"), MlirXlaOpKernel); +REGISTER_XLA_OP(Name("Select"), SelectOp); class SelectOpV2 : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 7e8889cb2ccee6..5d37b2f4283cf8 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -290,6 +290,8 @@ class ExpandDimsOp : public XlaOpKernel { " dimensions.")); auto existing_dims = input_shape.dim_sizes(); + auto existing_exprs = input_shape.get_expressions(); + // Safe - # elements in tensor dims bounded. const int existing_dims_size = static_cast(existing_dims.size()); std::vector new_shape(existing_dims_size); @@ -297,6 +299,12 @@ class ExpandDimsOp : public XlaOpKernel { new_shape[i] = existing_dims[i]; } + const int existing_exprs_size = static_cast(existing_exprs.size()); + std::vector new_exprs(existing_exprs_size); + for (size_t i = 0; i < new_exprs.size(); ++i) { + new_exprs[i] = existing_exprs[i]; + } + // We emulate numpy's interpretation of the dim axis when // -input.dims() >= dim <= input.dims(). if (dim < 0) { @@ -306,8 +314,9 @@ class ExpandDimsOp : public XlaOpKernel { // Clamp to the end if needed. 
dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); + new_exprs.emplace(new_exprs.begin() + dim, xla::DynExpr::one); - ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape, new_exprs)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), @@ -331,6 +340,7 @@ class SqueezeOp : public XlaOpKernel { absl::flat_hash_set wrapped_squeeze_dims; wrapped_squeeze_dims.reserve(squeeze_dims_.size()); std::vector new_shape; + std::vector new_exprs; // Validate squeeze dims against the input. for (int32_t dim : squeeze_dims_) { OP_REQUIRES( @@ -358,6 +368,7 @@ class SqueezeOp : public XlaOpKernel { } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); + new_exprs.push_back(shape.expressions(i)); } } else { OP_REQUIRES( @@ -368,11 +379,12 @@ class SqueezeOp : public XlaOpKernel { // Copy over all non-1-length dimensions. if (existing_dim != 1) { new_shape.push_back(existing_dim); + new_exprs.push_back(shape.expressions(i)); } } } - ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape, new_exprs)); } private: @@ -430,7 +442,8 @@ class ZerosLikeOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); xla::XlaOp input = ctx->Input(0); auto input_shape = ctx->InputXlaShape(0).value(); - auto result = xla::Broadcast(zero, input_shape.dimensions()); + auto result = xla::Broadcast(zero, input_shape.dimensions(), + input_shape.expressions()); // Setting up dynamic dimensions of the broadcast. 
for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { @@ -455,7 +468,8 @@ class OnesLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto one = XlaHelpers::One(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes(), + input_shape.get_expressions())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 844a31f97990fc..caffe874e9781c 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -66,13 +66,17 @@ class SliceOp : public XlaOpKernel { ctx->ConstantInputAsIntVector(2, &size).ok(); if (all_begins_are_constant && all_sizes_are_constant) { std::vector wrapped_size(size.size()); + std::vector wrapped_size_exprs(size.size()); // `begin` is a compile-time constant. for (int i = 0; i < input_dims; ++i) { if (size[i] == -1) { // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". wrapped_size[i] = input_shape.dim_size(i) - begin[i]; + wrapped_size_exprs[i] = + (*input_shape.get_expression(i) - begin[i])->s(); } else { wrapped_size[i] = size[i]; + wrapped_size_exprs[i] = xla::DynExpr::_(size[i]); } } @@ -97,13 +101,21 @@ class SliceOp : public XlaOpKernel { } } + std::vector begin_exprs; + for (int d : begin){ + begin_exprs.push_back(xla::DynExpr::_(d)); + } std::vector limits; + std::vector exprs; limits.reserve(begin.size()); + exprs.reserve(begin.size()); for (int i = 0; i < begin.size(); ++i) { limits.push_back(begin[i] + wrapped_size[i]); + exprs.push_back((*begin_exprs[i] + *wrapped_size_exprs[i])->s()); } std::vector strides(begin.size(), 1); - auto slice = xla::Slice(ctx->Input(0), begin, limits, strides); + auto slice = + xla::Slice(ctx->Input(0), begin, limits, begin_exprs, exprs, strides); // Check for slice on dynamic dimensions. 
std::vector size_is_dynamic; OP_REQUIRES_OK( @@ -114,8 +126,10 @@ class SliceOp : public XlaOpKernel { if (size[i] != -1) { // If there is a dynamic dimension, properly set dimension size of // the slice. - auto dynamic_size = - xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {}); + auto dynamic_size = xla::Reshape( + xla::Slice(ctx->Input(2), {i}, {i + 1}, {xla::DynExpr::_(i)}, + {xla::DynExpr::_(i + 1)}, {1}), + {}); slice = xla::SetDimensionSize(slice, dynamic_size, i); } @@ -153,7 +167,13 @@ class SliceOp : public XlaOpKernel { } if (all_sizes_are_constant && !constant_size_is_minus_one) { xla::XlaOp input = ctx->Input(0); - ctx->SetOutput(0, xla::DynamicSlice(input, begin_indices, size)); + std::vector output_exprs; + output_exprs.reserve(size.size()); + for (int64_t d : size) { + output_exprs.push_back(xla::DynExpr::_(d)); + } + ctx->SetOutput( + 0, xla::DynamicSlice(input, begin_indices, size, output_exprs)); } else { // Size is not constant, use input size as upperbound and then set // dimension size on it. diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 330479bc8d4150..b1d7e16bbda741 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -35,11 +35,72 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" +#include "tensorflow/compiler/tf2xla/type_util.h" + namespace tensorflow { namespace { -REGISTER_XLA_OP(Name("Softmax"), MlirXlaOpKernel); -REGISTER_XLA_OP(Name("LogSoftmax"), MlirXlaOpKernel); +class SoftmaxOp : public XlaOpKernel { + public: + explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + log_ = absl::StartsWith(type_string(), "Log"); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape logits_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_shape.DebugString())); + + // Major dimensions are batch dimensions, minor dimension is the class + // dimension. + std::vector batch_dims(logits_shape.dims() - 1); + std::iota(batch_dims.begin(), batch_dims.end(), 0); + const int kClassDim = logits_shape.dims() - 1; + + const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); + auto logits = ctx->Input(0); + + xla::XlaBuilder* const b = ctx->builder(); + + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); + + // Find the max in each batch, resulting in a tensor of shape [batch] + auto logits_max = + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); + // Subtract the max in batch b from every element in batch b. Broadcasts + // along the batch dimension. 
+ auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); + auto exp_shifted = xla::Exp(shifted_logits); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + xla::PrimitiveType xla_accumulation_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, + &xla_accumulation_type)); + auto converted = + xla::ConvertElementType(exp_shifted, xla_accumulation_type); + auto reduce = + xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum = XlaHelpers::ConvertElementType(reduce, type); + auto softmax = + log_ + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + : xla::Div(exp_shifted, sum, batch_dims); + ctx->SetOutput(0, softmax); + } + + private: + bool log_; +}; + +REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); +REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); std::pair CrossEntropyWithLogits( XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type, diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index b4d589f183108e..1368326cf851df 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -83,7 +83,8 @@ class SparseToDenseOp : public XlaOpKernel { sparse_values = Broadcast(sparse_values, {num_elems}); } xla::XlaBuilder* builder = context->builder(); - auto buffer = Broadcast(default_value, output_shape.dim_sizes()); + auto buffer = Broadcast(default_value, output_shape.dim_sizes(), + output_shape.get_expressions()); std::vector dynamic_dims; OP_REQUIRES_OK( context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 
4f7c4ae99b6b6b..bf0c294e21f003 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -81,17 +81,21 @@ class SplitOp : public XlaOpKernel { // All the slices are the same size: this is the size along the // split dimension. const int32_t slice_size = input_shape.dim_size(split_dim) / num_split; + auto slice_expr = *input_shape.get_expression(split_dim) / num_split; // The vectors we will use to define the slice. The entry for the // split dimensions varies for each output. std::vector begin(input_shape.dims(), 0); std::vector limits(input_shape.dims()); + std::vector begin_expr(input_shape.dims(), xla::DynExpr::zero); + std::vector limits_expr(input_shape.dims()); std::vector strides(input_shape.dims(), 1); for (int i = 0; i < input_shape.dims(); ++i) { // Initially set up the limits to be the full size of the input: // the split dimension is filled in below. int64_t dim = input_shape.dim_size(i); limits[i] = dim; + limits_expr[i] = input_shape.get_expression(i); } // Create each of the outputs. @@ -99,7 +103,12 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); + + begin_expr[split_dim] = i * *slice_expr; + limits_expr[split_dim] = (*xla::DynExpr::_(i + 1) * *slice_expr)->s(); + + ctx->SetOutput(i, xla::Slice(input, begin, limits, begin_expr, + limits_expr, strides)); } } }; @@ -202,21 +211,28 @@ class SplitVOp : public XlaOpKernel { input_shape.dim_size(split_dim) - total_split_size; } - // The vectors we will use to define the slice. The entry for the - // split dimensions varies for each output. + // The vectors we will use to define the slice. The entry for the split + // dimension varies for each output. 
std::vector begin(input_shape.dims(), 0); auto dim_sizes = input_shape.dim_sizes(); std::vector limits(dim_sizes.begin(), dim_sizes.end()); std::vector strides(input_shape.dims(), 1); + std::vector begin_expr(input_shape.dims(), + xla::DynExpr::zero); + auto input_exprs = input_shape.get_expressions(); + std::vector limits_expr(input_exprs.begin(), + input_exprs.end()); for (int i = 0; i < num_split; ++i) { - TensorShape output_shape(input_shape); int slice_size = split_sizes[i]; - output_shape.set_dim(split_dim, slice_size); + xla::DynExpr* slice_expr = xla::DynExpr::_(slice_size); // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); + limits_expr[split_dim] = (*begin_expr[split_dim] + *slice_expr)->s(); + ctx->SetOutput( + i, xla::Slice(input, begin, limits, begin_expr, limits_expr, strides)); begin[split_dim] = limits[split_dim]; + begin_expr[split_dim] = limits_expr[split_dim]; } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 3c99ad63565266..59716591ad08bc 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -157,7 +157,8 @@ class StackPushOp : public XlaOpKernel { TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = xla::Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes(), + slice_shape.get_expressions()); // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. 
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index aa71c5c34d2e1a..2f19d1f2dd9a03 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -49,8 +49,9 @@ xla::BitGeneratorTy GetBitGeneratorForDevice( device_type_string == DEVICE_CPU_XLA_JIT) { return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { std::tie(state, key) = xla::ScramblePhiloxKey(key); - xla::XlaOp philox_state = - xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0); + xla::XlaOp philox_state = xla::ConcatInDim( + key.builder(), + {xla::Reshape(key, {1}, {xla::DynExpr::one}), state}, 0); xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, philox_state, shape); return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), @@ -421,19 +422,23 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); - auto bcasted_means = BroadcastTo(ctx->Input(2), shape.dim_sizes()); + auto bcasted_means = + BroadcastTo(ctx->Input(2), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_means.status()); auto means = bcasted_means.value(); - auto bcasted_stddevs = BroadcastTo(ctx->Input(3), shape.dim_sizes()); + auto bcasted_stddevs = + BroadcastTo(ctx->Input(3), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_stddevs.status()); auto stddevs = bcasted_stddevs.value(); - auto bcasted_minvals = BroadcastTo(ctx->Input(4), shape.dim_sizes()); + auto bcasted_minvals = + BroadcastTo(ctx->Input(4), shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_minvals.status()); auto minvals = bcasted_minvals.value(); - auto bcasted_maxvals = BroadcastTo(ctx->Input(5), shape.dim_sizes()); + auto bcasted_maxvals = + BroadcastTo(ctx->Input(5), 
shape.dim_sizes(), shape.get_expressions()); OP_REQUIRES_OK(ctx, bcasted_maxvals.status()); auto maxvals = bcasted_maxvals.value(); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index e15196bd756462..df7deaaf80d3bf 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -82,6 +82,9 @@ class StridedSliceOp : public XlaOpKernel { partial_final_shape.set_dim( i, input_shape.dim_size(shape_spec.output_to_processing_mapping[i])); + partial_final_shape.set_expression( + i, input_shape.get_expression( + shape_spec.output_to_processing_mapping[i])); } } @@ -97,6 +100,8 @@ class StridedSliceOp : public XlaOpKernel { // Use input shape to update unknown dimension of partial shape -- if a // dimension is unknown, we use input shape as bound. partial_processing_shape.set_dim(i, input_shape.dim_size(i)); + partial_processing_shape.set_expression(i, + input_shape.get_expression(i)); } } TensorShape processing_shape; @@ -157,7 +162,10 @@ class StridedSliceOp : public XlaOpKernel { auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin")); xla::XlaOp begin_index, end_index; int64_t sparse_index = shape_spec.processing_to_sparse_mapping[i]; - bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i); + xla::DynExpr* input_expr = input_xla_shape.expressions(i); + bool xla_input_is_dynamic = + input_xla_shape.is_dynamic_dimension(i) || + (input_expr != nullptr && input_expr->is_dynamic()); xla::XlaOp dim_size; if (xla_input_is_dynamic) { dim_size = xla::GetDimensionSize(ctx->Input(0), i); @@ -215,10 +223,12 @@ class StridedSliceOp : public XlaOpKernel { } slice = - xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes()); + xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes(), + processing_shape.get_expressions()); // new_axis_mask_, ellipsis_mask_ and shrink_axis_mask_ may add or 
remove // size 1 dims of a shape. - slice = xla::Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes(), + final_shape.get_expressions()); for (int64_t i = 0; i < final_shape.dims(); ++i) { int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i]; // If processing_shape_dim is -1, it means the output dimension was newly @@ -246,6 +256,8 @@ class StridedSliceOp : public XlaOpKernel { absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; @@ -268,14 +280,15 @@ class StridedSliceOp : public XlaOpKernel { PartialTensorShape partial_processing_shape, partial_final_shape; bool dummy = false; StridedSliceShapeSpec shape_spec; + OP_REQUIRES_OK( - ctx, - ValidateStridedSliceOp( - begin_is_constant ? &begin_tensor : nullptr, - end_is_constant ? &end_tensor : nullptr, strides_tensor, - input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &partial_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec)); + ctx, ValidateStridedSliceOp( + begin_is_constant ? &begin_tensor : nullptr, + end_is_constant ? 
&end_tensor : nullptr, strides_tensor, + input_shape, begin_mask_, end_mask_, ellipsis_mask_, + new_axis_mask_, shrink_axis_mask_, &partial_processing_shape, + &partial_final_shape, &dummy, &dummy, &dummy, &begin, &end, + &strides, &begin_expr, &end_expr, &shape_spec)); xla::XlaOp slice = ctx->Input(0); std::vector begins_are_dynamic; @@ -294,17 +307,28 @@ class StridedSliceOp : public XlaOpKernel { ", output shape must be a compile-time constant")); absl::InlinedVector dimensions_to_reverse; absl::InlinedVector slice_begin, slice_end, slice_strides; + absl::InlinedVector slice_begin_expr, slice_end_expr; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { slice_begin.push_back(begin[i]); + slice_begin_expr.push_back(begin_expr[i]); slice_end.push_back(std::max(end[i], begin[i])); + slice_end_expr.push_back((end[i] > begin[i]) ? end_expr[i] + : begin_expr[i]); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. + auto input_exprs = input_shape.get_expressions(); slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); + slice_begin_expr.push_back( + (*input_exprs[i] - *begin_expr[i] - 1)->s()); slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, input_shape.dim_size(i) - begin[i] - 1)); + slice_end_expr.push_back( + (end[i] < begin[i]) + ? 
(*input_exprs[i] - *end_expr[i] - 1)->s() + : (*input_exprs[i] - *begin_expr[i] - 1)->s()); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } @@ -312,7 +336,8 @@ class StridedSliceOp : public XlaOpKernel { if (!dimensions_to_reverse.empty()) { slice = xla::Rev(slice, dimensions_to_reverse); } - slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); + slice = xla::Slice(slice, slice_begin, slice_end, slice_begin_expr, + slice_end_expr, slice_strides); auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); OP_REQUIRES_OK(ctx, operand_shape_or.status()); xla::Shape xla_shape = operand_shape_or.value(); @@ -325,7 +350,8 @@ class StridedSliceOp : public XlaOpKernel { bool ends_are_static = absl::c_all_of( ends_are_dynamic, [](bool dynamic) { return !dynamic; }); // Static output shape, return a static slice. - slice = xla::Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes(), + final_shape.get_expressions()); if (xla_shape.is_static() && ends_are_static) { ctx->SetOutput(0, slice); return; @@ -436,6 +462,8 @@ class StridedSliceGradOp : public XlaOpKernel { PartialTensorShape processing_shape, final_shape; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; StridedSliceShapeSpec shape_spec; OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, @@ -445,7 +473,7 @@ class StridedSliceGradOp : public XlaOpKernel { nullptr, nullptr, strides_tensor, input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, &dummy, - &begin, &end, &strides, &shape_spec)); + &begin, &end, &strides, &begin_expr, &end_expr, &shape_spec)); for (int64_t i = 0; i < processing_shape.dims(); ++i) { OP_REQUIRES( ctx, strides[i] == 1, @@ -459,14 +487,17 @@ class StridedSliceGradOp : public XlaOpKernel { VLOG(1) << "xla 
final_shape" << final_shape; VLOG(1) << "input_shape" << input_shape.DebugString(); auto input_sizes = input_shape.dim_sizes(); + auto input_exprs = input_shape.get_expressions(); // For unknown output dim the bound of the output shape is input. Pad and // double the size of input shape to leave enough buffer to avoid OOB // dynamic update slice. auto input_sizes_padded = input_shape.dim_sizes(); + auto input_exprs_padded = input_shape.get_expressions(); bool need_padding = false; for (int64_t i = 0; i < processing_shape.dims(); ++i) { if (processing_shape.dim_size(i) == -1) { input_sizes_padded[i] *= 2; + input_exprs_padded[i] = (2 * *input_exprs_padded[i])->s(); need_padding = true; } } @@ -477,6 +508,7 @@ class StridedSliceGradOp : public XlaOpKernel { if (shape_spec.output_to_processing_mapping[i] != -1) { processing_shape.set_dim(shape_spec.output_to_processing_mapping[i], grad_shape.dimensions(i)); + // TODO Pass it back } } @@ -506,15 +538,19 @@ class StridedSliceGradOp : public XlaOpKernel { } auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); - zero = xla::Broadcast(zero, input_sizes_padded); - grad = xla::Reshape(grad, processing_shape.dim_sizes()); + zero = xla::Broadcast(zero, input_sizes_padded, input_exprs_padded); + grad = xla::Reshape(grad, processing_shape.dim_sizes(), + processing_shape.get_expressions()); grad = xla::DynamicUpdateSlice(zero, grad, begins); if (need_padding) { // We padded the input shape to avoid OOB when DUS. Now slice out the // padding in the final result. 
std::vector strides(input_shape.dims(), 1); std::vector start_indices(input_shape.dims(), 0); - grad = xla::Slice(grad, start_indices, input_sizes, strides); + std::vector start_exprs(input_shape.dims(), + xla::DynExpr::zero); + grad = xla::Slice(grad, start_indices, input_sizes, start_exprs, + input_exprs, strides); } ctx->SetOutput(0, grad); } @@ -522,6 +558,8 @@ class StridedSliceGradOp : public XlaOpKernel { TensorShape processing_shape, final_shape; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; TensorShape input_shape; @@ -547,11 +585,12 @@ class StridedSliceGradOp : public XlaOpKernel { bool dummy = false; OP_REQUIRES_OK( - ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, input_shape, - begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &processing_shape, &final_shape, &dummy, - &dummy, &dummy, &begin, &end, &strides)); + ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, input_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, + &dummy, &begin, &end, &strides, &begin_expr, &end_expr)); // Check to make sure dy is consistent with the original slice const TensorShape dy_shape = ctx->InputShape(4); @@ -570,7 +609,8 @@ class StridedSliceGradOp : public XlaOpKernel { xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. - grad = xla::Reshape(grad, processing_shape.dim_sizes()); + grad = xla::Reshape(grad, processing_shape.dim_sizes(), + processing_shape.get_expressions()); // Pad the input gradients. 
absl::InlinedVector dimensions_to_reverse; @@ -662,6 +702,8 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape final_shape; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; @@ -690,12 +732,13 @@ class StridedSliceAssignOp : public XlaOpKernel { TensorShape dummy_processing_shape; bool dummy = false; - OP_REQUIRES_OK(ctx, - ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, lhs_shape, - begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, - shrink_axis_mask_, &dummy_processing_shape, &final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK( + ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, lhs_shape, begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &dummy_processing_shape, &final_shape, &dummy, &dummy, &dummy, + &begin, &end, &strides, &begin_expr, &end_expr)); if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { // DynamicUpdateSlice does not allow 0-element updates. We should probably @@ -717,6 +760,7 @@ class StridedSliceAssignOp : public XlaOpKernel { absl::InlinedVector dimensions_to_reverse; absl::InlinedVector slice_begin; absl::InlinedVector slice_dims; + absl::InlinedVector slice_exprs; for (int i = 0; i < begin.size(); ++i) { // TODO(b/121179231): implement strides != 1 OP_REQUIRES( @@ -726,12 +770,14 @@ class StridedSliceAssignOp : public XlaOpKernel { slice_begin.push_back( xla::ConstantR0(ctx->builder(), begin[i])); slice_dims.push_back(end[i] - begin[i]); + slice_exprs.push_back(xla::DynExpr::_(end[i] - begin[i])); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. 
slice_begin.push_back( xla::ConstantR0(ctx->builder(), end[i] + 1)); slice_dims.push_back(begin[i] - end[i]); + slice_exprs.push_back(xla::DynExpr::_(begin[i] - end[i])); dimensions_to_reverse.push_back(i); } } @@ -739,7 +785,7 @@ class StridedSliceAssignOp : public XlaOpKernel { if (!dimensions_to_reverse.empty()) { rhs = xla::Rev(rhs, dimensions_to_reverse); } - rhs = xla::Reshape(rhs, slice_dims); + rhs = xla::Reshape(rhs, slice_dims, slice_exprs); lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 888908e30b2331..f436472195c383 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -165,9 +165,11 @@ class TensorArrayOp : public XlaOpKernel { CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); + ta_shape.AddExpression(xla::DynExpr::_(size)); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); - value = xla::Broadcast(zero, ta_shape.dim_sizes()); + value = xla::Broadcast(zero, ta_shape.dim_sizes(), + ta_shape.get_expressions()); } XlaResource* var = @@ -223,7 +225,8 @@ class TensorArrayWriteOp : public XlaOpKernel { TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = xla::Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes(), + slice_shape.get_expressions()); xla::XlaOp written; if (resource->tensor_array_multiple_writes_aggregate()) { @@ -274,9 +277,12 @@ class TensorArrayReadOp : public XlaOpKernel { start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); + auto slice_exprs = ta_shape.get_expressions(); slice_shape[0] = 1LL; + slice_exprs[0] = xla::DynExpr::_(1LL); - xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = + xla::DynamicSlice(ta, start_indices, 
slice_shape, slice_exprs); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, @@ -469,9 +475,12 @@ class TensorArrayConcatOp : public XlaOpKernel { xla::XlaOp ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); + auto ta_exprs = ta_shape.get_expressions(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); + std::vector exprs(ta_exprs.begin() + 1, ta_exprs.end()); shape[0] *= ta_shape.dim_size(0); - ctx->SetOutput(0, xla::Reshape(ta, shape)); + exprs[0] = *ta_exprs[0] * *ta_shape.get_expression(0); + ctx->SetOutput(0, xla::Reshape(ta, shape, exprs)); Tensor lengths(DT_INT64, {ta_dims[0]}); auto lengths_vec = lengths.vec(); @@ -526,6 +535,7 @@ class TensorArraySplitOp : public XlaOpKernel { TensorShape ta_shape; ta_shape.AddDim(resource->max_array_size()); + ta_shape.AddExpression(xla::DynExpr::_(resource->max_array_size())); ta_shape.AppendShape(elem_shape); OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(), @@ -541,7 +551,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. 
", ta_shape.DebugString())); - const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + const xla::XlaOp reshape = + xla::Reshape(value, ta_shape.dim_sizes(), ta_shape.get_expressions()); if (dtype_ == DT_BOOL) { ta = xla::Or(ta, reshape); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index a1f58d5ae9b40e..ae0f27a4fad59f 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -131,7 +131,8 @@ absl::Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, return absl::OkStatus(); } - *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes()); + *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes(), + partial_shape.get_expressions()); *got_shape = true; return absl::OkStatus(); } @@ -503,6 +504,8 @@ class TensorListConcatOp : public XlaOpKernel { xla::Shape element_shape = std::move(shape_or).value(); std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); OP_REQUIRES( ctx, element_dims.size() > 1, errors::Unimplemented("TensorList of scalars is not supported")); @@ -510,12 +513,15 @@ class TensorListConcatOp : public XlaOpKernel { int64_t tensor_lengths = element_dims[1]; std::vector new_dims = {num_elements * tensor_lengths}; + std::vector new_exprs = { + xla::DynExpr::_(num_elements * tensor_lengths)}; for (int i = 2; i < element_dims.size(); i++) { new_dims.push_back(element_dims[i]); + new_exprs.push_back(element_exprs[i]); } - xla::XlaOp out = xla::Reshape(buffer, new_dims); + xla::XlaOp out = xla::Reshape(buffer, new_dims, new_exprs); ctx->SetOutput(0, out); // Second output is a tensor of lengths of returned tensors. 
@@ -550,6 +556,8 @@ class TensorListSplitOp : public XlaOpKernel { xla::Shape element_shape = std::move(shape_or).value(); std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); OP_REQUIRES( ctx, !element_dims.empty(), errors::Unimplemented("Element dimensions have to be non-empty")); @@ -569,11 +577,13 @@ class TensorListSplitOp : public XlaOpKernel { ctx, element_dims[0] % length == 0, errors::Unimplemented("Buffer size has to be a multiple of length")); std::vector new_dims = {element_dims[0] / length, length}; + std::vector new_exprs = {*element_exprs[0] / length, + xla::DynExpr::_(length)}; for (int i = 1; i < element_dims.size(); i++) { new_dims.push_back(element_dims[i]); } - xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims); + xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims, new_exprs); xla::XlaOp result; OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result)); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 683dc4737e6dab..1771181e440f31 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -389,7 +389,12 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector element_part_dims = xla::SpanToVector(element_part_shape.dimensions()); element_part_dims.insert(element_part_dims.begin(), 1); - element_part = xla::Reshape(element_part, element_part_dims); + std::vector element_part_exprs = + xla::SpanToVector(element_part_shape.expressions()); + element_part_exprs.insert(element_part_exprs.begin(), + xla::DynExpr::one); + element_part = + xla::Reshape(element_part, element_part_dims, element_part_exprs); std::vector start_indices( element_part_shape.dimensions().size() + 1, @@ -406,7 +411,10 @@ absl::Status 
ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); element_dims.insert(element_dims.begin(), 1); - xla::XlaOp update = xla::Reshape(element, element_dims); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); + element_exprs.insert(element_exprs.begin(), xla::DynExpr::one); + xla::XlaOp update = xla::Reshape(element, element_dims, element_exprs); std::vector start_indices(element_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); @@ -455,11 +463,16 @@ absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, xla::SpanToVector(list_part_shape.dimensions()); slice_shape[0] = 1LL; + std::vector slice_exprs = + xla::SpanToVector(list_part_shape.expressions()); + slice_exprs[0] = xla::DynExpr::_(1LL); + xla::XlaOp list_part = xla::GetTupleElement(list, i); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); slice_shape.erase(slice_shape.begin()); - element_result_parts.push_back(xla::Reshape(read, slice_shape)); + element_result_parts.push_back( + xla::Reshape(read, slice_shape, slice_exprs)); list_result_parts.push_back(list_part); } list_result_parts.push_back(push_index); @@ -493,7 +506,11 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, std::vector element_dims = xla::SpanToVector(element_shape.dimensions()); element_dims.insert(element_dims.begin(), 1); - xla::XlaOp update = xla::Reshape(element, element_dims); + std::vector element_exprs = + xla::SpanToVector(element_shape.expressions()); + element_exprs.insert(element_exprs.begin(), xla::DynExpr::one); + + xla::XlaOp update = xla::Reshape(element, element_dims, element_exprs); std::vector start_indices(element_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); @@ -557,6 +574,10 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::SpanToVector(buffer_shape.dimensions()); slice_shape[0] = 1LL; + 
std::vector slice_exprs = + xla::SpanToVector(buffer_shape.expressions()); + slice_exprs[0] = xla::DynExpr::_(1LL); + xla::XlaOp list_part = xla::GetTupleElement(list, 0); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); // Propagate dynamic dimensions from buffer to the sliced buffer, except for @@ -569,7 +590,8 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, } } slice_shape.erase(slice_shape.begin()); - *result = xla::Reshape(read, slice_shape); + slice_exprs.erase(slice_exprs.begin()); + *result = xla::Reshape(read, slice_shape, slice_exprs); return absl::OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 6c39981ba5b937..3eab14bf78e968 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -67,11 +67,17 @@ class TileOp : public XlaOpKernel { xla::ValueInferenceMode::kUpperBound)); std::vector output_dims(input_shape.dims()); + std::vector output_exprs(input_shape.dims()); + + auto expr_sizes = input_shape.get_expressions(); + for (int64_t i = 0; i < input_shape.dims(); ++i) { OP_REQUIRES(ctx, multiples_bounds[i] >= 0, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", output_dims[i])); output_dims[i] = input_shape.dim_size(i) * multiples_bounds[i]; + output_exprs[i] = + (*expr_sizes[i] * *xla::DynExpr::_(multiples_bounds[i]))->s(); } std::vector multiples_are_dynamic; @@ -91,8 +97,8 @@ class TileOp : public XlaOpKernel { return; } } - - auto result_or = BroadcastTo(ctx->Input("input"), output_dims); + auto result_or = + BroadcastTo(ctx->Input("input"), output_dims, output_exprs); OP_REQUIRES_OK(ctx, result_or.status()); auto result = result_or.value(); diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index 46de3dd89b6115..fbe181d13ef547 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -83,9 +83,12 @@ class UniqueOpBase : public XlaOpKernel { // // This is implemented as an hlo while loop. xla::XlaOp RollingSelectR1(XlaOpKernelContext* ctx, xla::XlaOp data, - xla::XlaOp mask, int64_t size) { + xla::XlaOp mask, int64_t size, + xla::DynExpr* expr) { xla::XlaComputation cond, body; - const xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size}); + xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size}); + r1_shape.set_expression(0, expr); + const xla::Shape counter_shape = xla::ShapeUtil::MakeScalarShape(xla::S32); const xla::Shape& single_element_shape = counter_shape; @@ -136,7 +139,7 @@ class UniqueOpBase : public XlaOpKernel { } auto zero = xla::Zero(ctx->builder(), xla::S32); - auto zero_broadcast = xla::Broadcast(zero, {size}); + auto zero_broadcast = xla::Broadcast(zero, {size}, {expr}); auto init = xla::Tuple(ctx->builder(), {zero, data, mask, zero_broadcast}); return xla::GetTupleElement(xla::While(cond, body, init), 3); } @@ -153,13 +156,19 @@ class UniqueOpBase : public XlaOpKernel { auto aux = MoveAxis(input, axis, 0, input_shape); auto aux_shape = ctx->builder()->GetShape(aux).value(); int64_t leading_size = aux_shape.dimensions(0); + auto leading_expr = aux_shape.expressions(0); int64_t product = 1; + auto product_expr = xla::DynExpr::one; for (int64_t i = 1; i < aux_shape.dimensions().size(); ++i) { product *= aux_shape.dimensions(i); + product_expr = *(product_expr) * *(aux_shape.expressions(i)); } - aux = xla::Reshape(aux, {leading_size, product}); + product_expr = product_expr->s(); + aux = xla::Reshape(aux, {leading_size, product}, + {leading_expr, product_expr}); if (leading_size == 0) { - auto result_data = xla::Reshape(aux, aux_shape.dimensions()); + auto result_data = + xla::Reshape(aux, aux_shape.dimensions(), aux_shape.expressions()); result_data = MoveAxis(result_data, 0, axis, aux_shape); 
ctx->SetOutput(0, result_data); ctx->SetOutput(1, xla::Iota(ctx->builder(), xla::S32, leading_size)); @@ -171,10 +180,13 @@ class UniqueOpBase : public XlaOpKernel { sort_types.reserve(product + 1); for (int64_t i = 0; i < product; ++i) { xla::XlaOp slice = xla::SliceInDim(aux, i, i + 1, 1, 1); - sort_keys.push_back(xla::Reshape(slice, {leading_size})); + sort_keys.push_back(xla::Reshape(slice, {leading_size}, {leading_expr})); sort_types.push_back(input_shape.element_type()); } - auto iota = xla::Iota(ctx->builder(), xla::S32, leading_size); + xla::Shape iota_shape = + xla::ShapeUtil::MakeShape(xla::S32, {leading_size}, {leading_expr}); + iota_shape.set_expression(0, leading_expr); + auto iota = xla::Iota(ctx->builder(), iota_shape, 0); sort_keys.push_back(iota); sort_types.push_back(xla::S32); @@ -202,16 +214,18 @@ class UniqueOpBase : public XlaOpKernel { gather_dim_numbers.add_collapsed_slice_dims(0); auto permuted = xla::Gather(aux, perm, gather_dim_numbers, {1, product}); // Tail is everything except for first element. - auto tail = xla::SliceInDim(permuted, 1, leading_size, 1, 0); + auto tail = xla::SliceInDim(permuted, 1, leading_size, + xla::DynExpr::one, leading_expr, 1, 0); // Init is everything except for last element. 
- auto init = xla::SliceInDim(permuted, 0, leading_size - 1, 1, 0); + auto init = xla::SliceInDim(permuted, 0, leading_size - 1, + xla::DynExpr::zero, *leading_expr - 1, 1, 0); auto ne = xla::Compare(tail, init, xla::ComparisonDirection::kNe); auto reduce = xla::Reduce(ne, xla::ConstantR0(ctx->builder(), false), CreateScalarOrComputation(xla::PRED, ctx->builder()), {1}); auto mask = xla::ConvertElementType(reduce, xla::S32); mask = xla::PadInDim(mask, xla::One(ctx->builder(), xla::S32), 0, 1, 0); - auto iperm = RollingSelectR1(ctx, perm, mask, leading_size); + auto iperm = RollingSelectR1(ctx, perm, mask, leading_size, leading_expr); auto sort_by_iperm = xla::Sort({iperm, mask, perm}, @@ -232,12 +246,15 @@ class UniqueOpBase : public XlaOpKernel { /*is_stable=*/true); auto mask_permute = xla::GetTupleElement(mask_sort, 1); permuted = xla::Gather(aux, mask_permute, gather_dim_numbers, {1, product}); - auto result_data = xla::Reshape(permuted, aux_shape.dimensions()); + auto result_data = + xla::Reshape(permuted, aux_shape.dimensions(), aux_shape.expressions()); result_data = MoveAxis(result_data, 0, axis, aux_shape); result_data = xla::SetDimensionSize(result_data, dynamic_size, axis); ctx->SetOutput(0, result_data); auto imask = CumSumR1(ctx, mask, leading_size); - imask = xla::Sub(imask, xla::One(ctx->builder(), xla::S32), {}); + auto one = xla::One(ctx->builder(), xla::S32); + auto one_broadcast = xla::Broadcast(one, {leading_size}, {leading_expr}); + imask = xla::Sub(imask, one_broadcast, {}); auto idx = xla::GetTupleElement( xla::Sort({perm_sort, imask}, xla::CreateScalarLtComputation({xla::S32, xla::S32}, diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index cca29f7f585907..d83c77e1b5a68f 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -69,7 +69,8 @@ class UnpackOp : public XlaOpKernel { limit_indices[axis] = i + 1; auto slice = 
xla::Slice(input, start_indices, limit_indices, strides); // Reshape to drop the 'axis' dimension. - auto result = xla::Reshape(slice, output_shape.dim_sizes()); + auto result = xla::Reshape(slice, output_shape.dim_sizes(), + output_shape.get_expressions()); ctx->SetOutput(i, result); } } diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index f97e6d5077efa7..29bcf4fa0769ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -165,7 +165,13 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { xla::ShapeUtil::MakeShape(xla::S32, input_shape.dimensions()); int64_t flattened_size = xla::Product(iota_shape.dimensions()); - XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size}); + xla::DynExpr* flattened_expr = xla::DynExpr::one; + for (auto e : iota_shape.expressions()){ + flattened_expr = *flattened_expr * *e; + } + flattened_expr = flattened_expr->s(); + XlaOp reshaped_condition = + xla::Reshape(condition, {flattened_size}, {flattened_expr}); XlaOp zeros = xla::ZerosLike(reshaped_condition); XlaOp compared = xla::Ne(reshaped_condition, zeros); @@ -175,7 +181,7 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { // indices of each element. 
for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis); - XlaOp reshaped = xla::Reshape(iota, {flattened_size}); + XlaOp reshaped = xla::Reshape(iota, {flattened_size}, {flattened_expr}); to_sort.push_back(reshaped); types_to_sort.push_back(xla::S32); } @@ -186,7 +192,8 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { std::vector to_concat; for (int64_t i = 0; i < iota_shape.dimensions_size(); ++i) { XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1); - to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1})); + to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}, + {flattened_expr, xla::DynExpr::one})); } XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1); @@ -214,7 +221,13 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { TF_ASSIGN_OR_RETURN(xla::Shape input_shape, b->GetShape(condition)); int64_t flattened_size = xla::Product(input_shape.dimensions()); - XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size}); + xla::DynExpr* flattened_expr = xla::DynExpr::one; + for (auto e : input_shape.expressions()) { + flattened_expr = *flattened_expr * *e; + } + flattened_expr = flattened_expr->s(); + XlaOp reshaped_condition = + xla::Reshape(condition, {flattened_size}, {flattened_expr}); XlaOp zeros = xla::ZerosLike(reshaped_condition); XlaOp preds = xla::ConvertElementType(xla::Ne(reshaped_condition, zeros), S32); @@ -253,7 +266,8 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { XlaOp out_idxs = xla::Select(xla::Ne(prefix_sum, prefix_sum_shifted), /*on_true=*/prefix_sum - xla::One(b, S32), /*on_false=*/oob_idx); - out_idxs = xla::Reshape(out_idxs, {flattened_size, 1}); + out_idxs = xla::Reshape(out_idxs, {flattened_size, 1}, + {flattened_expr, xla::DynExpr::one}); // tf.where returns an array of multidimensional indices where the condition // is true. 
For example: @@ -280,7 +294,8 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { iotas_to_concat.reserve(iota_shape.dimensions_size()); for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { iotas_to_concat.push_back( - xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1})); + xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}, + {flattened_expr, xla::DynExpr::one})); } XlaOp iotas = xla::ConcatInDim(b, iotas_to_concat, /*dimension=*/1); diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index b000c49f1f962e..fec1ce280cf922 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -133,6 +133,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { for (int64_t i = 0; i < original_shape.dimensions().size(); ++i) { to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + to_shape.set_expression(i, original_shape.expressions(i)); } } return xla::Reshape(to_shape, original); diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index f815b91d04be33..c866c4429d4818 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -32,9 +32,10 @@ limitations under the License. 
namespace tensorflow { -absl::StatusOr BroadcastTo(xla::XlaOp input, - absl::Span output_dims) { - return xla::BroadcastTo(input, output_dims); +absl::StatusOr BroadcastTo( + xla::XlaOp input, absl::Span output_dims, + absl::Span output_exprs) { + return xla::BroadcastTo(input, output_dims, output_exprs); } absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h index 60630971e27466..ee56975b664f7b 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.h +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -29,8 +29,9 @@ namespace tensorflow { // Forwards to xla::BroadcastTo. // TODO(cheshire): Call the underlying function directly. -absl::StatusOr BroadcastTo(xla::XlaOp input, - absl::Span output_dims); +absl::StatusOr BroadcastTo( + xla::XlaOp input, absl::Span output_dims, + absl::Span output_exprs = {}); // Forwards to xla::BroadcastOpsToSame. absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 2473b97af4c2bd..38116fddb200af 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -53,7 +53,12 @@ absl::StatusOr Contract(xla::XlaOp input, int64_t dim) { input_shape.dimensions().end() - 1); contracted_shape[dim] *= 4; - return xla::Reshape(xla::Transpose(input, permutation), contracted_shape); + std::vector contracted_exprs( + input_shape.expressions().begin(), input_shape.expressions().end() - 1); + contracted_exprs[dim] = (*(contracted_exprs[dim]) * *xla::DynExpr::_(4))->s(); + + return xla::Reshape(xla::Transpose(input, permutation), contracted_shape, + contracted_exprs); } absl::StatusOr Expand(xla::XlaOp input, int64_t dim) { @@ -85,7 +90,7 @@ absl::StatusOr Expand(xla::XlaOp input, int64_t dim) { } permutation.push_back(dim + 1); - return 
xla::Transpose(xla::Reshape(input, expanded_shape), permutation); + return xla::Transpose(xla::Reshape(input, expanded_shape, {}), permutation); } } // namespace diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 6a67cfa237af70..fa417215032662 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -1143,16 +1143,22 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, } std::vector dims; std::vector dynamic_dims; + std::vector expressions; for (int i = 0, rank = c->Rank(shape_handle); i < rank; ++i) { bool is_dynamic = !c->ValueKnown(c->Dim(shape_handle, i)); + int dynamic_multiplier = c->DynamicRatio(c->Dim(shape_handle, i)); dynamic_dims.push_back(is_dynamic); + expressions.push_back(dynamic_multiplier * *xla::DynExpr::V(1)); dims.push_back(is_dynamic ? xla::Shape::kUnboundedSize : c->Value(c->Dim(shape_handle, i))); } - return xla::Shape( + xla::Shape sh( // Type matters only for indices. S64 is the widest possible type. 
xla::PrimitiveType::S64, dims, absl::InlinedVector(dynamic_dims.begin(), dynamic_dims.end())); + + sh.set_expressions(expressions); + return sh; } REGISTER_OP("XlaGather") diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 0d7549d81c20f6..f2dccaea7b1cac 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -100,6 +100,9 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } + std::vector bexprs(shape.expressions().begin(), + shape.expressions().end()); + tensor_shape->set_expressions(bexprs); return absl::OkStatus(); } @@ -167,6 +170,7 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, } int rank = tensor_shape.dims(); std::vector dimensions(rank); + std::vector expressions(rank); std::vector layout(rank); for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); @@ -175,11 +179,13 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, "shape; returning unknown sentinel value"; return xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); } + expressions[d] = tensor_shape.get_expression(d); } // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); xla::Shape result = xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + result.set_expressions(expressions); return result; } @@ -200,18 +206,29 @@ absl::StatusOr TensorShapeToXLAShape( return out; } +inline static int var_id = 1; + xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector dimensions(rank); std::vector layout(rank); + std::vector expressions(rank); + for (int d = 0; d < rank; ++d) { dimensions[d] = tensor_shape.dim_size(d); + expressions[d] = (d < tensor_shape.get_expressions().size()) + ? 
tensor_shape.get_expression(d) + : xla::DynExpr::_(dimensions[d]); } + // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - return xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + auto shape = + xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + shape.set_expressions(expressions); + return shape; } absl::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index 9e2eccd29b1885..9873082c3f4d6e 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -116,6 +116,7 @@ struct XlaArgument { // Returns the dimension sizes for either TensorShape or xla::Shape. std::vector DimensionSizes() const; + std::vector DimensionExpressions() const; absl::InlinedVector DimensionSizesAsInlinedVector() const; // Returns the human-readable string for either TensorShape or xla::Shape. 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index b7cff00c8a0bfe..84a5b1e5c05eca 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -511,6 +511,14 @@ std::vector XlaCompiler::Argument::DimensionSizes() const { } } +std::vector XlaCompiler::Argument::DimensionExpressions() const { + if (absl::holds_alternative(shape)) { + return std::get(shape).get_expressions(); + } else { + return xla::SpanToVector(std::get(shape).expressions()); + } +} + absl::InlinedVector XlaCompiler::Argument::DimensionSizesAsInlinedVector() const { if (absl::holds_alternative(shape)) { @@ -839,7 +847,9 @@ absl::Status XlaCompiler::CompileFunction( std::vector{tensor_shape}); } } else { + auto* val_ptr = std::get_if(&args[i].shape); TensorShape tensor_shape = std::get(args[i].shape); + AttrSlice n_attrs = fbody->arg_nodes[i]->attrs(); fbody->arg_nodes[i]->ClearAttr("_output_shapes"); fbody->arg_nodes[i]->AddAttr("_output_shapes", std::vector{tensor_shape}); @@ -1083,6 +1093,7 @@ absl::Status XlaCompiler::BuildArguments( TF_RET_CHECK(absl::holds_alternative(arg.shape)); // TODO(phawkins): this code assumes that resource arguments do not // alias. + auto* val_ptr = std::get_if(&arg.shape); XlaResource* resource = context->AddResource(std::make_unique( arg.resource_kind, i, arg.name, arg.type, @@ -1203,6 +1214,10 @@ absl::Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, it == arg_shardings.end() ? std::optional() : it->second); + auto& arg = args[input_to_args->at(i)]; + xla::OpMetadata arg_metadata; + arg_metadata.set_op_name(arg.node_name); + builder->SetOneShotOpMetadata(arg_metadata); if (is_entry_computation) { // Add an entry to is_same_across_replicas for every leaf buffer. 
std::vector is_same_across_replicas( @@ -1222,10 +1237,9 @@ absl::Status XlaCompiler::BuildArguments( // Fill in the handles in non-constant arguments, and reshape parameters // back to their correct shapes. - VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; - VLOG(2) << " XLA arg " << i + VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) << " name: " << arg.name << " TF arg " << input_to_args->at(i) << " node name: " << arg.node_name @@ -1251,7 +1265,9 @@ absl::Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression = XlaExpression::XlaOp( - xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type); + xla::Reshape(arg_handles[i], arg.DimensionSizes(), + arg.DimensionExpressions()), + arg.type); } else { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); if (arg.value_bound) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 4a570827029330..e022b2fbec258b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -736,7 +736,8 @@ absl::Status AssignVariableTensor(const Tensor& tensor, DataType type, xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { - handle = xla::Reshape(handle, representation_shape.dimensions()); + handle = xla::Reshape(handle, representation_shape.dimensions(), + representation_shape.expressions()); } variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 2f0ff5e91867f1..72bd6f712672e3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1740,6 +1740,7 @@ 
tf_cuda_library( "@local_xla//xla/tsl/framework:cancellation", "@local_xla//xla/tsl/util:command_line_flags", "@local_xla//xla/tsl/util:device_name_utils", + "@local_xla//xla:shape_util", ] + if_cuda([ "@local_config_cuda//cuda:cudnn_header", ]) + if_static( diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 6820a5ddd696d3..0a80e85ba83955 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -327,7 +327,7 @@ void ConsiderConstantFoldableNode( std::unordered_map>* constant_control_deps, std::unordered_map>* shape_replacement_map, bool* internal_node_inserted) { - if (!IsConstantFoldable(n, opts.shape_map, opts.consider, + if (!IsConstantFoldable(n, opts.shape_map, opts.consider, opts.max_constant_size_in_bytes, shape_replacement_map)) { return; diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 09142e303e3e13..2c0d2ba5ee7ee9 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -62,6 +62,7 @@ package( exports_files( srcs = [ "allocator_registry.h", + "batch_size_resource.h", "collective.h", "control_flow.h", "dataset.h", @@ -195,6 +196,7 @@ filegroup( "allocator.h", "allocator_registry.h", "attr_value_util.h", + "batch_size_resource.h", "bfloat16.h", "bounds_check.h", "cancellation.h", @@ -340,6 +342,8 @@ filegroup( "tensor_key.h", "tensor_shape.cc", "tensor_shape.h", + "tensor_shape_expr.cc", + "tensor_shape_expr.h", "tensor_types.h", "tensor_util.h", "tracking_allocator.h", @@ -718,6 +722,23 @@ cc_library( "//tensorflow/core/platform:statusor", "//tensorflow/core/util:overflow", "@eigen_archive//:eigen3", + "@local_xla//xla:shape_util", + ], + alwayslink = 1, +) + +cc_library( + name = "tensor_shape_expr", + srcs = ["tensor_shape_expr.cc"], + hdrs = ["tensor_shape_expr.h"], + visibility = [ + "//tensorflow/core:__pkg__", + 
"//tensorflow/core/grappler:__subpackages__", + "//tensorflow/core/runtime_fallback:__subpackages__", + "//tensorflow/core/tfrt/utils:__subpackages__", + ], + deps = [ + ":tensor_shape_proto_cc", ], alwayslink = 1, ) @@ -921,6 +942,7 @@ cc_library( ":node_def_util", ":op_def_proto_cc", ":tensor_shape", + ":tensor_shape_expr", ":tensor_shape_proto_cc", "//tensorflow/core/lib/core:errors", "//tensorflow/core/lib/core:status", diff --git a/tensorflow/core/framework/batch_size_resource.h b/tensorflow/core/framework/batch_size_resource.h new file mode 100644 index 00000000000000..e457b65e3293e9 --- /dev/null +++ b/tensorflow/core/framework/batch_size_resource.h @@ -0,0 +1,16 @@ +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { +const string BatchSizeResourceName = "BatchSizeResource_"; +class BatchSizeResource : public ResourceBase { + public: + ~BatchSizeResource() override { + VLOG(1) << "BatchSizeResource destroyed with batch size: " << batch_size_; + } + string DebugString() const override { return BatchSizeResourceName; } + void SetBatchSize(size_t s) { batch_size_ = s; } + size_t GetBatchSize() { return batch_size_; } + private: + size_t batch_size_ = 0; +}; +} diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index aefa2a416310e2..42cf77245b8e1b 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -2369,7 +2369,13 @@ absl::Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } else if (dim_y.SameHandle(dim_x)) { dims.push_back(dim_x); } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) { - dims.push_back(c->UnknownDim()); + DimensionHandle merged; + absl::Status s = c->Merge(dim_x, dim_y, &merged); + if (s.ok()) { + dims.push_back(merged); + } else { + dims.push_back(c->UnknownDim()); + } } else { if (!incompatible_shape_error) { *out = c->UnknownShape(); diff --git 
a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index b63269f68c3368..b1336bf4398844 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -248,6 +248,10 @@ void InferenceContext::ShapeHandleToProto(ShapeHandle handle, dim_shape->set_size(Value(dim)); } else { dim_shape->set_size(-1); + // Serialize expression if available. + if (DimExpr* expr = GetDimExpr(dim)) { + expr->ToProto(dim_shape->mutable_expr()); + } } } } @@ -282,6 +286,36 @@ DimensionHandle InferenceContext::NumElements(ShapeHandle s) { } } +DimensionHandle InferenceContext::UnknownDimWithExpr( + std::unique_ptr expr) { + DimExpr* owned = shape_manager_.OwnExpr(std::move(expr)); + return shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/0, owned); +} + +DimExpr* InferenceContext::GetDimExpr(DimensionHandle d) const { + if (!d.IsSet()) return nullptr; + return d->expr_; +} + +DimExpr* InferenceContext::MakeConstExpr(int64_t v) { + return shape_manager_.OwnExpr(std::make_unique(v)); +} + +DimExpr* InferenceContext::ExprForDim(DimensionHandle d) { + if (!d.IsSet()) return nullptr; + + // If already tagged with expr, use it. + if (DimExpr* e = GetDimExpr(d)) return e; + + // Known dim -> const expr. + if (ValueKnown(d)) { + return MakeConstExpr(Value(d)); + } + + // Unknown dim with no expr -> cannot form expression. + return nullptr; +} + string InferenceContext::DebugString(ShapeHandle s) { if (RankKnown(s)) { std::vector vals; @@ -293,7 +327,7 @@ string InferenceContext::DebugString(ShapeHandle s) { } string InferenceContext::DebugString(DimensionHandle d) { - return ValueKnown(d) ? strings::StrCat(Value(d)) : "?"; + return ValueKnown(d) ? 
strings::StrCat(Value(d), strings::StrCat("~",DynamicRatio(d))) : "?"; } string InferenceContext::DebugString() const { @@ -895,7 +929,12 @@ absl::Status InferenceContext::MakeShapeFromPartialTensorShape( for (int i = 0; i < num_dims; ++i) { // -1 is unknown in PartialTensorShape and in InferenceContext, so this size // can be passed directly to MakeDim. - dims[i] = MakeDim(partial_shape.dim_size(i)); + if(i == 0){ + dims[i] = MakeDim(partial_shape.dim_size(i), 1); + } + else { + dims[i] = MakeDim(partial_shape.dim_size(i)); + } } return ReturnCreatedShape(dims, out); } @@ -923,8 +962,38 @@ absl::Status InferenceContext::MakeShapeFromShapeProto( const TensorShapeProto& proto, ShapeHandle* out) { *out = nullptr; TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); - PartialTensorShape partial_shape(proto); - return MakeShapeFromPartialTensorShape(partial_shape, out); + + if (proto.unknown_rank()) { + *out = UnknownShape(); + return absl::OkStatus(); + } + + std::vector dims; + dims.reserve(proto.dim_size()); + for (int i = 0; i < proto.dim_size(); ++i) { + const auto& dim_proto = proto.dim(i); + if (dim_proto.size() >= 0) { + // Known dimension + dims.push_back(MakeDim(dim_proto.size())); + } else { + // Unknown dimension - check for expression + if (dim_proto.has_expr() && dim_proto.expr().node_type_case() != + ExpressionProto::NODE_TYPE_NOT_SET) { + // Deserialize expression + std::unique_ptr expr = DimExpr::FromProto(dim_proto.expr()); + if (expr) { + DimExpr* owned = shape_manager_.OwnExpr(std::move(expr)); + dims.push_back(shape_manager_.MakeDim(kUnknownDim,/*dynamic_ratio */ 0, owned)); + } else { + dims.push_back(UnknownDim()); + } + } else { + dims.push_back(UnknownDim()); + } + } + } + *out = MakeShape(dims); + return absl::OkStatus(); } absl::Status InferenceContext::GetScalarFromTensor(const Tensor* t, @@ -1030,24 +1099,40 @@ absl::Status InferenceContext::Divide(DimensionHandle dividend, DimensionOrConstant divisor, bool evenly_divisible, 
DimensionHandle* out) { - const int64_t divisor_value = Value(divisor); - if (divisor_value == 1) { + const bool dividend_known = ValueKnown(dividend); + const bool divisor_known = ValueKnown(divisor); + + // Validate divisor if known. + if (divisor_known && Value(divisor) <= 0) { + return errors::InvalidArgument("Divisor must be positive but is ", + Value(divisor)); + } + // Fast-path: x / 1 = x + if (divisor_known && Value(divisor) == 1) { *out = dividend; - } else if (!ValueKnown(dividend) || - (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) { - *out = UnknownDim(); - } else { + return absl::OkStatus(); + } + // If both known, do numeric divide. + if (dividend_known && divisor_known) { const int64_t v = Value(dividend); - if (divisor_value <= 0) { - return errors::InvalidArgument("Divisor must be positive but is ", - divisor_value); - } - if (evenly_divisible && (v % divisor_value) != 0) { + const int64_t d = Value(divisor); + if (evenly_divisible && (v % d) != 0) { return errors::InvalidArgument( - "Dimension size must be evenly divisible by ", divisor_value, - " but is ", v); + "Dimension size must be evenly divisible by ", d, " but is ", v); } - *out = MakeDim(v / divisor_value); + *out = MakeDim(v / d); + return absl::OkStatus(); + } + // At least one operand unknown: try to build expression. + DimExpr* lhs = ExprForDim(dividend); + DimExpr* rhs = divisor.dim.IsSet() ? ExprForDim(divisor.dim) + : MakeConstExpr(divisor.val); + if (lhs && rhs) { + DimExpr* node = shape_manager_.OwnExpr( + std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/0, node); + } else { + *out = UnknownDim(); // Can't form expr. 
} return absl::OkStatus(); } @@ -1055,26 +1140,41 @@ absl::Status InferenceContext::Divide(DimensionHandle dividend, absl::Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second, DimensionHandle* out) { - const int64_t first_value = Value(first); - const int64_t second_value = Value(second); - // Special cases. - if (first_value == 0) { + const bool first_known = ValueKnown(first); + const bool second_known = ValueKnown(second); + + // Fast-path: x + 0 = x + if (first_known && Value(first) == 0) { *out = MakeDim(second); - } else if (second_value == 0) { + return absl::OkStatus(); + } + if (second_known && Value(second) == 0) { *out = first; - } else if (first_value == kUnknownDim || second_value == kUnknownDim) { - *out = UnknownDim(); - } else { - // Invariant: Both values are known and positive. Still in run-time we can - // get pair of values which cannot be store in output. Check below will - // report error. We still need to avoid undefined behavior of signed - // overflow and use unsigned addition. - const int64_t sum = static_cast(first_value) + second_value; + return absl::OkStatus(); + } + + // If both known, do numeric add. + if (first_known && second_known) { + const int64_t sum = static_cast(Value(first)) + + static_cast(Value(second)); if (sum < 0) { return errors::InvalidArgument("Dimension size overflow from adding ", - first_value, " and ", second_value); + Value(first), " and ", Value(second)); } *out = MakeDim(sum); + return absl::OkStatus(); + } + + // At least one operand unknown: try to build expression. + DimExpr* lhs = ExprForDim(first); + DimExpr* rhs = + second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); + + if (lhs && rhs) { + DimExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); + } else { + *out = UnknownDim(); // Can't form expr. 
} return absl::OkStatus(); } @@ -1082,22 +1182,34 @@ absl::Status InferenceContext::Add(DimensionHandle first, absl::Status InferenceContext::Subtract(DimensionHandle first, DimensionOrConstant second, DimensionHandle* out) { - const int64_t first_value = Value(first); - const int64_t second_value = Value(second); - // Special cases. - if (second_value == 0) { + const bool first_known = ValueKnown(first); + const bool second_known = ValueKnown(second); + // Fast-path: x - 0 = x + if (second_known && Value(second) == 0) { *out = first; - } else if (first_value == kUnknownDim || second_value == kUnknownDim) { - *out = UnknownDim(); - } else { - // Invariant: Both values are known, first_value is non-negative, and - // second_value is positive. + return absl::OkStatus(); + } + // If both known, do numeric subtract. + if (first_known && second_known) { + const int64_t first_value = Value(first); + const int64_t second_value = Value(second); if (first_value < second_value) { return errors::InvalidArgument( "Negative dimension size caused by subtracting ", second_value, " from ", first_value); } *out = MakeDim(first_value - second_value); + return absl::OkStatus(); + } + // At least one operand unknown: try to build expression. + DimExpr* lhs = ExprForDim(first); + DimExpr* rhs = + second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); + if (lhs && rhs) { + DimExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); + } else { + *out = UnknownDim(); // Can't form expr. 
} return absl::OkStatus(); } @@ -1105,21 +1217,31 @@ absl::Status InferenceContext::Subtract(DimensionHandle first, absl::Status InferenceContext::Multiply(DimensionHandle first, DimensionOrConstant second, DimensionHandle* out) { + const bool first_known = ValueKnown(first); + const bool second_known = ValueKnown(second); const int64_t first_value = Value(first); const int64_t second_value = Value(second); - // Special cases. - if (first_value == 0) { + + // Fast-paths for identity and zero cases. + if (first_known && first_value == 0) { *out = first; - } else if (second_value == 0) { + return absl::OkStatus(); + } + if (second_known && second_value == 0) { *out = MakeDim(second); - } else if (first_value == 1) { + return absl::OkStatus(); + } + if (first_known && first_value == 1) { *out = MakeDim(second); - } else if (second_value == 1) { + return absl::OkStatus(); + } + if (second_known && second_value == 1) { *out = first; - } else if (first_value == kUnknownDim || second_value == kUnknownDim) { - *out = UnknownDim(); - } else { - // Invariant: Both values are known and greater than 1. + return absl::OkStatus(); + } + + // If both known, do numeric multiply. + if (first_known && second_known) { const int64_t product = MultiplyWithoutOverflow(first_value, second_value); if (product < 0) { return errors::InvalidArgument( @@ -1127,6 +1249,19 @@ absl::Status InferenceContext::Multiply(DimensionHandle first, first_value, " and ", second_value); } *out = MakeDim(product); + return absl::OkStatus(); + } + + // At least one operand unknown: try to build expression. + DimExpr* lhs = ExprForDim(first); + DimExpr* rhs = + second.dim.IsSet() ? ExprForDim(second.dim) : MakeConstExpr(second.val); + + if (lhs && rhs) { + DimExpr* node = shape_manager_.OwnExpr(std::make_unique(lhs, rhs)); + *out = shape_manager_.MakeDim(kUnknownDim, /*dynamic_ratio*/ 0, node); + } else { + *out = UnknownDim(); // Can't form expr. 
} return absl::OkStatus(); } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 8bfd301d860de1..70d09d8fde5e93 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" @@ -116,13 +117,16 @@ class InferenceContext; class Dimension { private: Dimension(); - Dimension(int64_t value); + Dimension(int64_t value, int64_t dynamic_ratio = 0, DimExpr* expr = nullptr); ~Dimension() {} const int64_t value_; + const int64_t dynamic_ratio_; + DimExpr* expr_; friend class InferenceContext; friend class ShapeManager; + friend class ::tensorflow::grappler::SymbolicShapeManager; Dimension(const Dimension&) = delete; void operator=(const Dimension&) = delete; }; @@ -439,6 +443,9 @@ class InferenceContext { static inline int64_t Value(DimensionOrConstant d) { return d.dim.IsSet() ? d.dim->value_ : d.val; } + static inline int64_t DynamicRatio(DimensionOrConstant d) { + return d.dim->dynamic_ratio_ ; + } static inline bool ValueKnown(DimensionOrConstant d) { return Value(d) != kUnknownDim; } @@ -572,12 +579,26 @@ class InferenceContext { // Returns a new dimension of the given size. The returned value is owned by // this context. - inline DimensionHandle MakeDim(DimensionOrConstant d) { - return shape_manager_.MakeDim(d); + inline DimensionHandle MakeDim(DimensionOrConstant d, int64_t dynamic_ratio = 0) { + return shape_manager_.MakeDim(d, dynamic_ratio); } inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } + // Create a new unknown dimension (size = -1) tagged with a DimExpr. 
+ // The expression is owned by this context's ShapeManager. + DimensionHandle UnknownDimWithExpr(std::unique_ptr expr); + // Return the expression pointer for a dimension, or nullptr if none. + DimExpr* GetDimExpr(DimensionHandle d) const; + // Creates a constant DimExpr node for the given value. + // The expression is owned by this context's ShapeManager. + DimExpr* MakeConstExpr(int64_t v); + // Returns the Expr representation for the given dimension: + // - If dim has an expr, returns it + // - If dim is known, returns a new Const expr + // - If dim is unknown with no expr, returns nullptr + DimExpr* ExprForDim(DimensionHandle d); + // Returns in a scalar value from an input tensor . The input tensor // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the // input tensor is not NULL. @@ -743,8 +764,6 @@ class InferenceContext { // Adds new outputs; useful when mutating the graph. absl::Status ExpandOutputs(int new_output_size); - - private: // Creates and stores shapes for use in InferenceContext. class ShapeManager { public: @@ -760,21 +779,31 @@ class InferenceContext { // Returns a new dimension of the given size. The returned value // is owned by this class. - inline DimensionHandle MakeDim(DimensionOrConstant d) { + inline DimensionHandle MakeDim(DimensionOrConstant d, int64_t dynamic_ratio = 0, DimExpr* expr = nullptr) { if (d.dim.IsSet()) { return d.dim; } else { - all_dims_.push_back(new Dimension(d.val)); + all_dims_.push_back(new Dimension(d.val, dynamic_ratio, expr)); return all_dims_.back(); } } + // Takes ownership of an expression and returns a raw pointer to it. + DimExpr* OwnExpr(std::unique_ptr expr) { + if (!expr) return nullptr; + DimExpr* ptr = expr.get(); + all_exprs_.push_back(std::move(expr)); + return ptr; + } private: std::vector all_shapes_; // values are owned. std::vector all_dims_; // values are owned. + std::vector> all_exprs_; // expressions are owned. 
}; + private: friend class ::tensorflow::grappler::GraphProperties; + friend class ::tensorflow::grappler::SymbolicShapeManager; friend class ShapeInferenceTest; // For testing Relax functions. friend class ShapeInferenceTestutil; // For testing shapes. @@ -888,8 +917,8 @@ class InferenceContext { // ----------------------------------------------------------------------------- // Template and inline method implementations, please ignore -inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} -inline Dimension::Dimension(int64_t value) : value_(value) { +inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim), dynamic_ratio_(0), expr_(nullptr) {} +inline Dimension::Dimension(int64_t value, int64_t dynamic_ratio, DimExpr* expr) : value_(value), dynamic_ratio_(dynamic_ratio), expr_(expr) { DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) << "Dimension must be non-negative or equal to " "InferenceContext::kUnknownDim but got " diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 35c628216ed3c6..31076d0433731f 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -26,6 +26,105 @@ limitations under the License. 
namespace tensorflow { +xla::DynExpr* ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return xla::DynExpr::_(proto.constant_value()); + + case ExpressionProto::kVariableId: + return xla::DynExpr::V(proto.variable_id()); + + case ExpressionProto::kAddNode: { + const auto& add = proto.add_node(); + return *ExprFromProto(add.lhs()) + *ExprFromProto(add.rhs()); + } + + case ExpressionProto::kSubNode: { + const auto& sub = proto.sub_node(); + return *ExprFromProto(sub.lhs()) - *ExprFromProto(sub.rhs()); + } + + case ExpressionProto::kMulNode: { + const auto& mul = proto.mul_node(); + return *ExprFromProto(mul.lhs()) * *ExprFromProto(mul.rhs()); + } + + case ExpressionProto::kDivNode: { + const auto& div = proto.div_node(); + return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); + } + + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) { + auto e = expr->s(); + if (xla::Constant* c = dynamic_cast(e)) { + proto->set_constant_value(c->get_val()); + } else if (xla::Variable* v = dynamic_cast(e)) { + proto->set_variable_id(v->get_id()); + } else if (xla::Add* a = dynamic_cast(e)) { + auto* add_msg = proto->mutable_add_node(); + ExprToProto(a->get_lhs(), add_msg->mutable_lhs()); + ExprToProto(a->get_rhs(), add_msg->mutable_rhs()); + } else if (xla::Mul* m = dynamic_cast(e)) { + auto* mul_msg = proto->mutable_mul_node(); + ExprToProto(m->get_lhs(), mul_msg->mutable_lhs()); + ExprToProto(m->get_rhs(), mul_msg->mutable_rhs()); + } else if (xla::Sub* s = dynamic_cast(e)) { + auto* sub_msg = proto->mutable_sub_node(); + ExprToProto(s->get_lhs(), sub_msg->mutable_lhs()); + ExprToProto(s->get_rhs(), sub_msg->mutable_rhs()); + } else if (xla::Div* d = dynamic_cast(e)) { + auto* div_msg = proto->mutable_div_node(); + ExprToProto(d->get_lhs(), div_msg->mutable_lhs()); + ExprToProto(d->get_rhs(), 
div_msg->mutable_rhs()); + } +} + +// Independent helper function to handle the recursion +void BuildExprString(xla::DynExpr* e, std::ostringstream& oss) { + if (xla::Constant* c = dynamic_cast(e)) { + oss << c->get_val(); + } else if (xla::Variable* v = dynamic_cast(e)) { + char letter = 'A' + (v->get_id() - 1); + oss << letter; + } else if (xla::Add* a = dynamic_cast(e)) { + oss << "("; + BuildExprString(a->get_lhs(), oss); + oss << " + "; + BuildExprString(a->get_rhs(), oss); + oss << ")"; + } else if (xla::Mul* m = dynamic_cast(e)) { + oss << "("; + BuildExprString(m->get_lhs(), oss); + oss << " * "; + BuildExprString(m->get_rhs(), oss); + oss << ")"; + } else if (xla::Sub* s = dynamic_cast(e)) { + oss << "("; + BuildExprString(s->get_lhs(), oss); + oss << " - "; + BuildExprString(s->get_rhs(), oss); + oss << ")"; + } else if (xla::Div* d = dynamic_cast(e)) { + oss << "("; + BuildExprString(d->get_lhs(), oss); + oss << " / "; + BuildExprString(d->get_rhs(), oss); + oss << ")"; + } +} + +std::string ExprToString(xla::DynExpr* e) { + std::ostringstream oss; + BuildExprString(e, oss); + return oss.str(); +} + // TensorShape and PartialTensorShape should have no fields beyond // TensorShapeRep. In particular, their sizes should be the same. 
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape), @@ -152,6 +251,9 @@ TensorShapeBase::TensorShapeBase(const TensorShapeProto& proto) { for (const auto& d : proto.dim()) { AddDim(d.size()); } + for (const auto& e : proto.expressions()) { + AddExpression(ExprFromProto(e)); + } } } @@ -191,6 +293,9 @@ absl::Status TensorShapeBase::BuildTensorShapeBase( } } } + for (const auto& e : proto.expressions()) { + out->AddExpression(ExprFromProto(e)); + } } return absl::OkStatus(); } @@ -375,6 +480,19 @@ void TensorShapeRep::Clear() { set_data_type(DT_INVALID); } +void TensorShapeRep::set_expression(int d, xla::DynExpr* expr) { + expressions_[d] = expr; +} + +void TensorShapeRep::AddExpression(xla::DynExpr* expr) { + CHECK_LT(expressions_.size(), ndims_byte()); + expressions_.push_back(expr); +} + +void TensorShapeRep::set_expressions(std::vector exprs) { + expressions_ = exprs; +} + void TensorShapeRep::ClearAllButDataType() { if (tag() == REP_OUT_OF_LINE) { delete as64()->dims_; @@ -505,6 +623,9 @@ void TensorShapeBase::UnsafeAddDim(int64_t size, template void TensorShapeBase::AppendShape(const TensorShapeBase& shape) { for (auto d : shape) AddDim(d.size); + for (auto e : shape.get_expressions()){ + AddExpression(e); + } } template @@ -585,6 +706,7 @@ template void TensorShapeBase::set_dim(int d, int64_t size) { CHECK_GE(d, 0); CHECK_LT(d, dims()); + if (get_expressions().size() > d) set_expression(d, xla::DynExpr::_(size)); if (!kIsPartial) { CHECK_GE(size, 0); } @@ -646,6 +768,7 @@ absl::Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { } } + if (get_expressions().size() > d) set_expression(d, xla::DynExpr::_(size)); return RecomputeNumElements(); } @@ -661,8 +784,28 @@ void TensorShapeBase::RemoveDimRange(int begin, int end) { if (begin >= end) return; absl::InlinedVector vals; AppendTo(*this, &vals); + std::vector new_exprs = get_expressions(); + if (begin < static_cast(new_exprs.size())) { + int64_t expr_end = end; + if (expr_end > 
static_cast(new_exprs.size())) { + expr_end = new_exprs.size(); + } + if (expr_end > begin) { + new_exprs.erase(new_exprs.begin() + begin, new_exprs.begin() + expr_end); + } + } + vals.erase(vals.begin() + begin, vals.begin() + end); + + // Truncate if the removed dims reduce rank below expression vector size. + const int64_t new_rank = vals.size(); + if (new_exprs.size() > static_cast(new_rank)) { + new_exprs.resize(new_rank); + } + + ClearAllButDataType(); + set_expressions(new_exprs); for (auto dval : vals) { AddDim(dval); } @@ -700,9 +843,28 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, absl::InlinedVector vals; AppendTo(*this, &vals); + + std::vector new_exprs = get_expressions(); + + if (begin < static_cast(new_exprs.size())) { + int64_t expr_end = end; + if (expr_end > static_cast(new_exprs.size())) { + expr_end = new_exprs.size(); + } + if (expr_end > begin) { + new_exprs.erase(new_exprs.begin() + begin, new_exprs.begin() + expr_end); + } + } + vals.erase(vals.begin() + begin, vals.begin() + end); ClearAllButDataType(); + const int64_t new_rank = vals.size(); + if (new_exprs.size() > static_cast(new_rank)) { + new_exprs.resize(new_rank); + } + + set_expressions(new_exprs); absl::Status s = absl::OkStatus(); for (auto dval : vals) { s.Update(AddDimWithStatus(dval)); @@ -710,7 +872,6 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, return s; } } - return RecomputeNumElements(); } @@ -731,6 +892,10 @@ void TensorShapeBase::AsProto(TensorShapeProto* proto) const { for (int i = 0; i < dims(); i++) { proto->add_dim()->set_size(dim_size(i)); } + for (int i = 0; i < get_expressions().size(); i++) { + ExpressionProto* eproto = proto->add_expressions(); + ExprToProto(get_expression(i), eproto); + } } } @@ -764,6 +929,11 @@ string TensorShapeRep::DebugString() const { } else { strings::StrAppend(&s, dim); } + if (shape.get_expression(i) != nullptr) { + strings::StrAppend(&s, "<"); + strings::StrAppend(&s, 
ExprToString(shape.get_expression(i))); + strings::StrAppend(&s, ">"); + } } strings::StrAppend(&s, "]"); return s; @@ -787,6 +957,15 @@ string TensorShapeRep::DebugString(const TensorShapeProto& proto) { first = false; } strings::StrAppend(&s, "]"); + strings::StrAppend(&s, "<"); + first = true; + for (const auto& e : proto.expressions()) { + if (!first) strings::StrAppend(&s, ","); + auto exp = ExprFromProto(e); + strings::StrAppend(&s, ExprToString(exp)); + first = false; + } + strings::StrAppend(&s, ">"); return s; } @@ -950,6 +1129,7 @@ absl::Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, return s; } } + result->set_expressions(shape.get_expressions()); return absl::OkStatus(); } diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index 0bcf1fc54af844..f2731f2cb444be 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" +#include "xla/shape_dynexpr.h" namespace tensorflow { @@ -73,7 +74,33 @@ class TensorShapeRep { std::string DebugString() const; static std::string DebugString(const TensorShapeProto& proto); + void set_expression(int d, xla::DynExpr* expr); + + void AddExpression(xla::DynExpr* expr); + + // Set the array of dynamic multipliers. + void set_expressions(std::vector exprs); + + // Get the array of dynamic multipliers. + std::vector get_expressions() const { + return expressions_; + } + + // Return the multiplier for a specific dynamic dimension. + // -1 if the dimension is not dynamic. + xla::DynExpr* get_expression(int64_t dimension) const { + if (dimension < 0) return xla::DynExpr::_(-999); + const size_t dim = static_cast(dimension); + if (dim >= expressions_.size()) { + return xla::DynExpr::_(-999); + } + return expressions_[dim] != nullptr ? 
expressions_[dim] + : xla::DynExpr::_(-999); + } + protected: + std::vector expressions_; + // Constructable only via TensorShapeBase TensorShapeRep() = default; @@ -710,6 +737,7 @@ absl::Status TensorShape::AsEigenDSizesWithPaddingWithStatus( inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { num_elements_ = b.num_elements_; + expressions_ = b.expressions_; if (b.tag() != REP_OUT_OF_LINE) { memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above Implicitly does: @@ -723,6 +751,7 @@ inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) { num_elements_ = b.num_elements_; + expressions_ = b.expressions_; memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above Implicitly does: // set_ndims_byte(b.ndims_byte()); @@ -738,6 +767,8 @@ inline TensorShapeRep::~TensorShapeRep() { inline void TensorShapeRep::operator=(const TensorShapeRep& b) { num_elements_ = b.num_elements_; + expressions_ = b.expressions_; + if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) { memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above implicitly also does: @@ -753,6 +784,8 @@ inline void TensorShapeRep::operator=(TensorShapeRep&& b) { DestructorOutOfLine(); } num_elements_ = b.num_elements_; + expressions_ = b.expressions_; + memcpy(buf(), b.buf(), sizeof(u_.buf)); // memcpy above Implicitly does: // set_ndims_byte(b.ndims_byte()); diff --git a/tensorflow/core/framework/tensor_shape.proto b/tensorflow/core/framework/tensor_shape.proto index 45d5b78ecbbc4c..f69b4228a7fb31 100644 --- a/tensorflow/core/framework/tensor_shape.proto +++ b/tensorflow/core/framework/tensor_shape.proto @@ -22,6 +22,10 @@ message TensorShapeProto { // Optional name of the tensor dimension. string name = 2; + //Only keep one symbolic expr. + //Symbolic expression for this dimension when size == -1. + //Allows tracking relationships between unknown dimensions. 
+ ExpressionProto expr = 3; }; // Dimensions of the tensor, such as {"input", 30}, {"output", 40} @@ -43,4 +47,38 @@ message TensorShapeProto { // // If true, "dim.size()" must be 0. bool unknown_rank = 3; + + repeated ExpressionProto expressions = 4; + }; + +message ExpressionProto { + oneof node_type { + int32 constant_value = 1; // cons + int32 variable_id = 2; // var + AddNode add_node = 3; // exp + exp + SubNode sub_node = 4; // exp - exp + MulNode mul_node = 5; // exp * exp + DivNode div_node = 6; // exp / exp + } +} + +message AddNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message SubNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message MulNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message DivNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} \ No newline at end of file diff --git a/tensorflow/core/framework/tensor_shape_expr.cc b/tensorflow/core/framework/tensor_shape_expr.cc new file mode 100644 index 00000000000000..37d1f33064957d --- /dev/null +++ b/tensorflow/core/framework/tensor_shape_expr.cc @@ -0,0 +1,196 @@ +#include "tensorflow/core/framework/tensor_shape_expr.h" + +namespace tensorflow { + +std::unique_ptr DimExpr::Cons(int64_t val) { + return std::make_unique(val); +} + +std::unique_ptr DimExpr::Var(int32_t id) { + return std::make_unique(id); +} + +std::string DimExpr::DebugString() const { + ExpressionProto proto; + ToProto(&proto); + return proto.DebugString(); +} + +static bool EqualsImpl(const DimExpr* a, const DimExpr* b) { + if (a == b) return true; + if (a == nullptr || b == nullptr) return false; + if (a->kind() != b->kind()) return false; + + switch (a->kind()) { + case DimExpr::Kind::kConstant: { + auto* ac = static_cast(a); + auto* bc = static_cast(b); + return ac->value() == bc->value(); + } + case DimExpr::Kind::kVariable: { + auto* av = static_cast(a); + auto* bv = static_cast(b); + return av->id() == bv->id(); + } + case DimExpr::Kind::kAdd: { + 
auto* aa = static_cast(a); + auto* ba = static_cast(b); + return EqualsImpl(aa->lhs(), ba->lhs()) && + EqualsImpl(aa->rhs(), ba->rhs()); + } + case DimExpr::Kind::kSub: { + auto* as = static_cast(a); + auto* bs = static_cast(b); + return EqualsImpl(as->lhs(), bs->lhs()) && + EqualsImpl(as->rhs(), bs->rhs()); + } + case DimExpr::Kind::kMul: { + auto* am = static_cast(a); + auto* bm = static_cast(b); + return EqualsImpl(am->lhs(), bm->lhs()) && + EqualsImpl(am->rhs(), bm->rhs()); + } + case DimExpr::Kind::kDiv: { + auto* ad = static_cast(a); + auto* bd = static_cast(b); + return EqualsImpl(ad->lhs(), bd->lhs()) && + EqualsImpl(ad->rhs(), bd->rhs()); + } + } + + return false; +} + +bool DimExpr::Equals(const DimExpr* a, const DimExpr* b) { + return EqualsImpl(a, b); +} + +std::unique_ptr DimExpr::FromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DimExpr::Cons(proto.constant_value()); + case ExpressionProto::kVariableId: + return DimExpr::Var(proto.variable_id()); + case ExpressionProto::kAddNode: { + auto lhs = FromProto(proto.add_node().lhs()); + auto rhs = FromProto(proto.add_node().rhs()); + // Note: These are owning pointers, but ExprAdd takes raw pointers. + // The caller must manage lifetime appropriately. 
+ return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kSubNode: { + auto lhs = FromProto(proto.sub_node().lhs()); + auto rhs = FromProto(proto.sub_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kMulNode: { + auto lhs = FromProto(proto.mul_node().lhs()); + auto rhs = FromProto(proto.mul_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::kDivNode: { + auto lhs = FromProto(proto.div_node().lhs()); + auto rhs = FromProto(proto.div_node().rhs()); + return std::make_unique(lhs.release(), rhs.release()); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +DimExpr* SimplifyExpr(DimExpr* expr, + std::vector>* arena) { + if (!expr) return nullptr; + + auto own = [arena](std::unique_ptr e) -> DimExpr* { + DimExpr* ptr = e.get(); + arena->push_back(std::move(e)); + return ptr; + }; + + switch (expr->kind()) { + case DimExpr::Kind::kConstant: + case DimExpr::Kind::kVariable: + return expr; + + case DimExpr::Kind::kAdd: { + auto* add = static_cast(expr); + DimExpr* lhs = SimplifyExpr(add->lhs(), arena); + DimExpr* rhs = SimplifyExpr(add->rhs(), arena); + + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + return own(DimExpr::Cons(lhs->ConstantValue() + rhs->ConstantValue())); + } + + // x + 0 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 0) return lhs; + if (lhs->IsConstant() && lhs->ConstantValue() == 0) return rhs; + + return own(std::make_unique(lhs, rhs)); + } + + case DimExpr::Kind::kSub: { + auto* sub = static_cast(expr); + DimExpr* lhs = SimplifyExpr(sub->lhs(), arena); + DimExpr* rhs = SimplifyExpr(sub->rhs(), arena); + + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + return own(DimExpr::Cons(lhs->ConstantValue() - rhs->ConstantValue())); + } + + // x - 0 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 0) return lhs; + + return 
own(std::make_unique(lhs, rhs)); + } + + case DimExpr::Kind::kMul: { + auto* mul = static_cast(expr); + DimExpr* lhs = SimplifyExpr(mul->lhs(), arena); + DimExpr* rhs = SimplifyExpr(mul->rhs(), arena); + + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + return own(DimExpr::Cons(lhs->ConstantValue() * rhs->ConstantValue())); + } + + // x * 1 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 1) return lhs; + if (lhs->IsConstant() && lhs->ConstantValue() == 1) return rhs; + + // x * 0 → 0 + if (rhs->IsConstant() && rhs->ConstantValue() == 0) + return own(DimExpr::Cons(0)); + if (lhs->IsConstant() && lhs->ConstantValue() == 0) + return own(DimExpr::Cons(0)); + + return own(std::make_unique(lhs, rhs)); + } + + case DimExpr::Kind::kDiv: { + auto* div = static_cast(expr); + DimExpr* lhs = SimplifyExpr(div->lhs(), arena); + DimExpr* rhs = SimplifyExpr(div->rhs(), arena); + + // Constant folding (avoid div by zero) + if (lhs->IsConstant() && rhs->IsConstant()) { + int64_t r = rhs->ConstantValue(); + if (r != 0) { + return own(DimExpr::Cons(lhs->ConstantValue() / r)); + } + } + + // x / 1 → x + if (rhs->IsConstant() && rhs->ConstantValue() == 1) return lhs; + + return own(std::make_unique(lhs, rhs)); + } + } + + return expr; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_shape_expr.h b/tensorflow/core/framework/tensor_shape_expr.h new file mode 100644 index 00000000000000..2979df64af639c --- /dev/null +++ b/tensorflow/core/framework/tensor_shape_expr.h @@ -0,0 +1,219 @@ +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace tensorflow { + +// Forward declarations +class Constant; +class Variable; +class ExprAdd; +class ExprSub; +class ExprMul; +class ExprDiv; + +// DimExpr: Base class for symbolic expressions representing dynamic dimension +// sizes. 
These expressions form a DAG that tracks how unknown dimensions relate +// to each other through arithmetic operations. +// +// The expression language: +// - Var(sym_id): A symbolic variable representing an unknown dimension +// - Const(k): A known constant value +// - Add/Sub/Mul/Div(lhs, rhs): Binary arithmetic operations +// +// INVARIANT: An unknown dimension is not just -1, it is -1 + Var(sym). +class DimExpr { + public: + enum class Kind : uint8_t { + kConstant, + kVariable, + kAdd, + kSub, + kMul, + kDiv, + }; + + virtual ~DimExpr() = default; + + virtual Kind kind() const = 0; + virtual void ToProto(ExpressionProto* proto) const = 0; + + virtual bool IsConstant() const { return false; } + virtual int64_t ConstantValue() const { return 0; } + + // Factory methods - return owning pointers + static std::unique_ptr Cons(int64_t val); + static std::unique_ptr Var(int32_t var_id); + + // Structural equality check + static bool Equals(const DimExpr* a, const DimExpr* b); + + // Build from proto (owns all returned nodes) + static std::unique_ptr FromProto(const ExpressionProto& proto); + + // Debug representation + std::string DebugString() const; + + protected: + DimExpr() = default; +}; + +// Constant expression node: represents a known integer value +class Constant final : public DimExpr { + public: + explicit Constant(int64_t value) : value_(value) {} + + Kind kind() const override { return Kind::kConstant; } + void ToProto(ExpressionProto* proto) const override { + proto->set_constant_value(value_); + } + + bool IsConstant() const override { return true; } + int64_t ConstantValue() const override { return value_; } + + int64_t value() const { return value_; } + + private: + int64_t value_; +}; + +// Variable expression node: represents a symbolic unknown dimension +class Variable final : public DimExpr { + public: + explicit Variable(int32_t id) : id_(id) {} + + Kind kind() const override { return Kind::kVariable; } + void ToProto(ExpressionProto* proto) 
const override { + proto->set_variable_id(id_); + } + + int32_t id() const { return id_; } + + private: + int32_t id_; +}; + +// Addition expression node +class ExprAdd final : public DimExpr { + public: + ExprAdd(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kAdd; } + void ToProto(ExpressionProto* proto) const override { + auto* add_msg = proto->mutable_add_node(); + lhs_->ToProto(add_msg->mutable_lhs()); + rhs_->ToProto(add_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + return lhs_->ConstantValue() + rhs_->ConstantValue(); + } + + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } + + private: + DimExpr* lhs_; + DimExpr* rhs_; +}; + +// Subtraction expression node +class ExprSub final : public DimExpr { + public: + ExprSub(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kSub; } + void ToProto(ExpressionProto* proto) const override { + auto* sub_msg = proto->mutable_sub_node(); + lhs_->ToProto(sub_msg->mutable_lhs()); + rhs_->ToProto(sub_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + return lhs_->ConstantValue() - rhs_->ConstantValue(); + } + + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } + + private: + DimExpr* lhs_; + DimExpr* rhs_; +}; + +// Multiplication expression node +class ExprMul final : public DimExpr { + public: + ExprMul(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kMul; } + void ToProto(ExpressionProto* proto) const override { + auto* mul_msg = proto->mutable_mul_node(); + lhs_->ToProto(mul_msg->mutable_lhs()); + rhs_->ToProto(mul_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return 
lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + return lhs_->ConstantValue() * rhs_->ConstantValue(); + } + + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } + + private: + DimExpr* lhs_; + DimExpr* rhs_; +}; + +// Division expression node +class ExprDiv final : public DimExpr { + public: + ExprDiv(DimExpr* lhs, DimExpr* rhs) : lhs_(lhs), rhs_(rhs) {} + + Kind kind() const override { return Kind::kDiv; } + void ToProto(ExpressionProto* proto) const override { + auto* div_msg = proto->mutable_div_node(); + lhs_->ToProto(div_msg->mutable_lhs()); + rhs_->ToProto(div_msg->mutable_rhs()); + } + + bool IsConstant() const override { + return lhs_->IsConstant() && rhs_->IsConstant(); + } + int64_t ConstantValue() const override { + int64_t r = rhs_->ConstantValue(); + return (r == 0) ? 0 : lhs_->ConstantValue() / r; + } + + DimExpr* lhs() const { return lhs_; } + DimExpr* rhs() const { return rhs_; } + + private: + DimExpr* lhs_; + DimExpr* rhs_; +}; + +// Simplify an expression tree: constant folding and algebraic identities. +// Returns a NEW expression (does not mutate input). +// The arena parameter is used to allocate nodes that will be owned externally. 
+DimExpr* SimplifyExpr(DimExpr* expr, + std::vector>* arena); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_EXPR_H_ diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index bb47f37ef7fbe3..d1616e2396515e 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -79,6 +79,20 @@ absl::Status FeedInputs( TF_RETURN_IF_ERROR( feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node)); + // Set an attribute in _Arg node to indicate it has a batch dimension + auto node_attrs = n->attrs(); + const AttrValue* shape_attr = node_attrs.FindByString("_output_shapes"); + if (shape_attr && shape_attr->has_list()) { + const TensorShapeProto& shape = shape_attr->list().shape(0); + for (int i = 0; i < shape.dim_size(); ++i) { + if (shape.dim(i).size() == -1) { + feed_node->AddAttr("_dynamic_dim", i); + break; + } + } + // Keep _output_shapes for further runs of shape inference + feed_node->AddAttr("_output_shapes", *shape_attr); + } // Update name_index (*name_index)[feed_node->name()] = feed_node; // Duplicate control edges aren't allowed, but feed_node was *just* created diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 613b12bb18ae3a..6a24831fd8fc0f 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/mutable_graph_view.h" @@ -189,6 +190,9 @@ class DisjointSet { absl::Status Merge(Handle x, Handle y); const typename HandleToObject::Object GetMergedValue(Handle value); + // Returns a pointer that uniquely identifies the set containing `value`. + // This can be used as a stable key for associating metadata with a set. + void* RootId(Handle value) { return static_cast(Find(value)); } private: // All the handles that belong to the same set are part of the same tree, and @@ -1148,32 +1152,32 @@ class SymbolicShapeRefiner { } struct ShapeId { - const NodeDef* node; + std::string node_name; int port_id; friend bool operator==(const ShapeId& lhs, const ShapeId& rhs) { - return lhs.node == rhs.node && lhs.port_id == rhs.port_id; + return lhs.node_name == rhs.node_name && lhs.port_id == rhs.port_id; } template friend H AbslHashValue(H h, const ShapeId& s) { - return H::combine(std::move(h), s.node, s.port_id); + return H::combine(std::move(h), s.node_name, s.port_id); } }; struct DimId { - const NodeDef* node; + std::string node_name; int port_id; int dim_index; friend bool operator==(const DimId& lhs, const DimId& rhs) { - return lhs.node == rhs.node && lhs.port_id == rhs.port_id && + return lhs.node_name == rhs.node_name && lhs.port_id == rhs.port_id && lhs.dim_index == rhs.dim_index; } template friend H AbslHashValue(H h, const DimId& d) { - return H::combine(std::move(h), d.node, d.port_id, d.dim_index); + return H::combine(std::move(h), d.node_name, d.port_id, d.dim_index); } }; @@ -1391,7 +1395,7 @@ class SymbolicShapeRefiner { // Return the one ShapeHandle used to denote a fully unknown shape for a node // output. 
ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) { - ShapeId id{node, index}; + ShapeId id{node->name(), index}; auto it = unknown_shapes_.find(id); if (it != unknown_shapes_.end()) { return it->second; @@ -1405,16 +1409,38 @@ class SymbolicShapeRefiner { // node output. DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index, int dim_id) { - DimId id{node, index, dim_id}; + DimId id{node->name(), index, dim_id}; auto it = unknown_dims_.find(id); if (it != unknown_dims_.end()) { return it->second; } InferenceContext* c = GetContext(node); - DimensionHandle dim = c->UnknownDim(); + int var_id = GetOrCreateStableVarId(id); + DimensionHandle dim; + if (node->op() == "_Arg") { + var_id *= -1; + // var_id would be minus when it's argument. + dim = c->UnknownDimWithExpr(DimExpr::Var(var_id)); + } else { + dim = c->UnknownDimWithExpr(DimExpr::Var(var_id)); + } + VLOG(1) << "[EXPR] GetUnknownOutputDim: node=" << node->name() + << " out=" << index << " dim=" << dim_id << " -> Var(" << var_id + << ")"; + // Create an unknown dim with Var(var_id) expression. unknown_dims_[id] = dim; return dim; } + // Get or create a stable integer variable ID for a given DimId. + int GetOrCreateStableVarId(const DimId& id) { + auto it = stable_var_ids_.find(id); + if (it != stable_var_ids_.end()) { + return it->second; + } + int var_id = next_var_id_++; + stable_var_ids_[id] = var_id; + return var_id; + } // Returns true if all the output tensors have known values. bool AllOutputValuesKnown(NodeContext* c) { @@ -1712,7 +1738,50 @@ class SymbolicShapeRefiner { // but instantiate a new UnknownDim to prevent incorrect symbolic shape // inference through UnknownDim from Const. 
InferenceContext* ic = c->inference_context.get(); + const std::string& op = node.op(); + const bool is_bin = + (op == "Sub" || op == "Add" || op == "Mul" || op == "Div"); if (!is_fed) { + if (is_bin) { + if (c->input_tensors_as_shapes_to_propagate.size() < 2) + return absl::OkStatus(); + auto va = c->input_tensors_as_shapes_to_propagate[0]; + auto vb = c->input_tensors_as_shapes_to_propagate[1]; + + if (va.SameHandle(tensorflow::shape_inference::ShapeHandle()) || + vb.SameHandle(tensorflow::shape_inference::ShapeHandle())) { + return absl::OkStatus(); + } + + if (!ic->RankKnown(va) || !ic->RankKnown(vb)) return absl::OkStatus(); + if (ic->Rank(va) != ic->Rank(vb)) return absl::OkStatus(); + + std::vector out_elems; + out_elems.reserve(ic->Rank(va)); + + for (int i = 0; i < ic->Rank(va); ++i) { + auto da = ic->Dim(va, i); + auto db = ic->Dim(vb, i); + + tensorflow::shape_inference::DimensionHandle r; + if (op == "Sub") + TF_RETURN_IF_ERROR(ic->Subtract(da, db, &r)); + else if (op == "Add") + TF_RETURN_IF_ERROR(ic->Add(da, db, &r)); + else if (op == "Mul") + TF_RETURN_IF_ERROR(ic->Multiply(da, db, &r)); + else + TF_RETURN_IF_ERROR( + ic->Divide(da, db, /*evenly_divisible=*/false, &r)); + out_elems.push_back(r); + } + c->output_tensors_as_shapes.resize(1); + c->output_tensors_as_shapes[0] = ic->MakeShape(out_elems); + // @TODO: Check if we need to do anything with output_tensor_protos. 
+ // S.t c->output_tensor_protos[0] = nullptr; + return absl::OkStatus(); + } + if (IsConstant(node)) { const TensorProto& tensor_proto = node.attr().at("value").tensor(); c->output_tensor_protos.resize(1); @@ -1759,7 +1828,7 @@ class SymbolicShapeRefiner { for (int i = 0; i < c->inference_context->num_inputs(); ++i) { c->output_tensors_as_shapes[i] = c->inference_context->input(i); } - } else if (node.op() == "ConcatV2") { + } else if (op == "ConcatV2") { bool valid = true; ShapeHandle result; for (int i = 0; i < ic->num_inputs() - 1; ++i) { @@ -1919,6 +1988,100 @@ class SymbolicShapeRefiner { return absl::OkStatus(); } + absl::Status CanonicalizeOutputDims(const NodeDef* node) { + NodeContext* ctx = GetNodeContext(node); + if (!ctx) return absl::OkStatus(); + + InferenceContext* ic = ctx->inference_context.get(); + for (int out = 0; out < ic->num_outputs(); ++out) { + ShapeHandle s = ic->output(out); + + if (!ic->RankKnown(s)) { + bool recovered_rank = false; + auto it = node->attr().find("_output_shapes"); + if(it != node->attr().end() && out < it->second.list().shape_size()){ + it = node->attr().find("shape"); + } + if (it != node->attr().end() && out < it->second.list().shape_size()) { + const TensorShapeProto& proto = it->second.list().shape(out); + if (!proto.unknown_rank()) { + std::vector dims; + dims.reserve(proto.dim_size()); + + for (int d = 0; d < proto.dim_size(); ++d) { + int64_t size = proto.dim(d).size(); + if (size >= 0) { + dims.push_back(ic->MakeDim(size)); + } else { + dims.push_back(GetUnknownOutputDim(node, out, d)); + } + } + s = ic->MakeShape(dims); + ic->set_output(out, s); + recovered_rank = true; + } + } + + if (!recovered_rank && node->op() == "_Arg") { + DimensionHandle d0 = GetUnknownOutputDim(node, out, /*dim_id=*/0); + ShapeHandle vec = ic->MakeShape({d0}); + ic->set_output(out, vec); + s = vec; + recovered_rank = true; + } + + if (!recovered_rank) { + VLOG(1) << "RANK still unknown. 
" << node->name(); + continue; + } + } + + if (!ic->RankKnown(s)) { + continue; + } + + bool changed = false; + std::vector dims; + dims.reserve(ic->Rank(s)); + for (int d = 0; d < ic->Rank(s); ++d) { + DimensionHandle dim = ic->Dim(s, d); + const int64_t v = ic->Value(dim); + // Keep concrete dims. + if (v >= 0) { + dims.push_back(dim); + continue; + } + // If already tagged with expr, keep it. + auto* dim_expr = ic->GetDimExpr(dim); + if (dim_expr != nullptr) { + dims.push_back(dim); + continue; + } + auto output_shapes_it = node->attr().find("_output_shapes"); + if (output_shapes_it != node->attr().end() && + out < output_shapes_it->second.list().shape_size() && + d < output_shapes_it->second.list().shape(out).dim_size()) { + const int64_t annotated_size = + output_shapes_it->second.list().shape(out).dim(d).size(); + if (annotated_size >= 0) { + changed = true; + dims.push_back(ic->MakeDim(annotated_size)); + continue; + } + } + // Canonicalize ALL unknown dims. + DimensionHandle canon = GetUnknownOutputDim(node, out, d); + changed |= !dim.SameHandle(canon); + dims.push_back(canon); + } + if (changed) { + ShapeHandle new_s = ic->MakeShape(dims); + ic->set_output(out, new_s); + } + } + return absl::OkStatus(); + } + absl::Status InferShapes(const NodeDef& node, NodeContext* c) { // Infer the shapes of output tensors. if (!c->op_data || c->op_data->shape_inference_fn == nullptr || @@ -1940,8 +2103,8 @@ class SymbolicShapeRefiner { status.Update(SetUnknownShape(&node, output_port)); } } - // Update NodeContext output fields after shape inference function runs. + status.Update(CanonicalizeOutputDims(&node)); status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c)); return status; @@ -2048,6 +2211,9 @@ class SymbolicShapeRefiner { absl::flat_hash_map node_to_context_; absl::flat_hash_map unknown_shapes_; absl::flat_hash_map unknown_dims_; + // Stable variable IDs for canonical dimension symbols. 
+ absl::flat_hash_map stable_var_ids_; + int next_var_id_ = 1; // Store function instantiations only for valid function. If function // instantiation failed it will have an `absl::nullopt`. absl::flat_hash_map> @@ -2080,8 +2246,9 @@ class SymbolicShapeManager { if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) { CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2)); for (int i = 0; i < InferenceContext::Rank(s1); ++i) { - TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i), - InferenceContext::DimKnownRank(s2, i))); + TF_RETURN_IF_ERROR( + MergeDimsWithExpr(InferenceContext::DimKnownRank(s1, i), + InferenceContext::DimKnownRank(s2, i))); } } return absl::OkStatus(); @@ -2090,7 +2257,7 @@ class SymbolicShapeManager { if (!d1.IsSet() || !d2.IsSet()) { return absl::OkStatus(); } - return dims_.Merge(d1, d2); + return MergeDimsWithExpr(d1, d2); } void AsTensorProperties(const ShapeHandle& shape, const DataType& type, @@ -2104,7 +2271,19 @@ class SymbolicShapeManager { shape_inference::DimensionHandle dim = InferenceContext::DimKnownRank(actual_shape, j); int64_t d = dims_.GetMergedValue(dim); - properties->mutable_shape()->add_dim()->set_size(d); + auto* out_dim = properties->mutable_shape()->add_dim(); + out_dim->set_size(d < 0 ? -1 : d); + void* root = dims_.RootId(dim); + DimExpr* expr = nullptr; + if (auto it = dim_root_expr_.find(root); it != dim_root_expr_.end()) { + expr = it->second; + } else { + expr = ExprForDim(dim); + } + if (expr != nullptr) { + expr->ToProto(out_dim->mutable_expr()); + // TODO: Apply simplification? + } } } } @@ -2132,7 +2311,139 @@ class SymbolicShapeManager { } private: + // Get the variable ID from an expression, or -1 if not a variable. 
+ static int32_t GetVarId(const DimExpr* e) { + if (!e || e->kind() != DimExpr::Kind::kVariable) return -1; + return static_cast(e)->id(); + } + + static bool IsConst(const DimExpr* e) { + return e && e->kind() == DimExpr::Kind::kConstant; + } + + static bool IsVar(const DimExpr* e) { + return e && e->kind() == DimExpr::Kind::kVariable; + } + + static bool IsPlaceHolder(const DimExpr* e) { + if (!e) return false; + if (e->kind() != DimExpr::Kind::kVariable) return false; + return static_cast(e)->id() < 0; + } + + static bool IsCompound(const DimExpr* e) { + if (!e) return false; + switch (e->kind()) { + case DimExpr::Kind::kAdd: + case DimExpr::Kind::kSub: + case DimExpr::Kind::kMul: + case DimExpr::Kind::kDiv: + return true; + default: + return false; + } + } + + // Ranking: Const > Arg_ > Compound > Var > null + static int InfoScore(const DimExpr* e) { + if (!e) return 0; + if (IsConst(e)) return 4; + if (IsPlaceHolder(e)) return 3; + if (IsCompound(e)) return 2; + if (IsVar(e)) return 1; + return 1; // fallback (shouldn't happen) + } + + // Prefer "more informative" but keep deterministic tie-break. + static DimExpr* PreferMoreInformative(DimExpr* a, DimExpr* b) { + if (a == b) return a; + const int sa = InfoScore(a); + const int sb = InfoScore(b); + if (sa > sb) return a; + if (sb > sa) return b; + // Same score: keep stable choice. + return a; + } + + // Get the expr pointer from a dimension handle (accesses private member). 
+ static DimExpr* GetExprFromDimHandle(const DimensionHandle& d) { + if (!d.IsSet()) return nullptr; + return d->expr_; + } + + DimExpr* ExprForDim(const DimensionHandle& d) { + if (!d.IsSet()) return nullptr; + if (DimExpr* expr = GetExprFromDimHandle(d)) { + return expr; + } + if (!InferenceContext::ValueKnown(d)) { + return nullptr; + } + const int64_t value = InferenceContext::Value(d); + auto it = const_exprs_.find(value); + if (it != const_exprs_.end()) { + return it->second.get(); + } + auto expr = DimExpr::Cons(value); + DimExpr* expr_ptr = expr.get(); + const_exprs_.emplace(value, std::move(expr)); + return expr_ptr; + } + + absl::Status MergeDimsWithExpr(DimensionHandle d1, DimensionHandle d2) { + if (!d1.IsSet() || !d2.IsSet()) return absl::OkStatus(); + + void* r1 = dims_.RootId(d1); + void* r2 = dims_.RootId(d2); + + // Fetch best-known expr for each set. + auto get_best = [&](void* r, DimensionHandle d) -> DimExpr* { + auto it = dim_root_expr_.find(r); + if (it != dim_root_expr_.end()) return it->second; + return ExprForDim(d); // may be null + }; + + DimExpr* e1 = get_best(r1, d1); + DimExpr* e2 = get_best(r2, d2); + + // If already in same UF set, just keep the most informative expr. + if (r1 == r2) { + DimExpr* existing = nullptr; + if (auto it = dim_root_expr_.find(r1); it != dim_root_expr_.end()) { + existing = it->second; + } + DimExpr* chosen = PreferMoreInformative(existing, + PreferMoreInformative(e1, e2)); + if (chosen) dim_root_expr_[r1] = chosen; // keep or upgrade + return absl::OkStatus(); + } + + // Perform UF merge (rank + path compression inside DisjointSet). + TF_RETURN_IF_ERROR(dims_.Merge(d1, d2)); + + // New root after merge. + void* new_root = dims_.RootId(d1); + + // Choose best expr across both sets. + DimExpr* chosen = PreferMoreInformative(e1, e2); + + // Remove stale root keys (only the old roots). + dim_root_expr_.erase(r1); + dim_root_expr_.erase(r2); + + // Preserve any expr already stored at new_root (rare but safe). 
+ if (auto it = dim_root_expr_.find(new_root); it != dim_root_expr_.end()) { + chosen = PreferMoreInformative(it->second, chosen); + } + + if (chosen) dim_root_expr_[new_root] = chosen; + + return absl::OkStatus(); + } DisjointSet shapes_; + absl::flat_hash_map> const_exprs_; + // Map from union-find root pointer to the best expression for that set. + absl::flat_hash_map dim_root_expr_; DisjointSet dims_; }; diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index a5a48347f07517..c468801a6a74db 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -741,6 +741,16 @@ bool IsBiasSemanticAdd(const RemapperContext& ctx, return false; } +bool MaybeCopyOutputShapesAttr(const NodeDef& from, NodeDef* fused_op) { + const string output_shape_attr_name = "_output_shapes"; + if (from.attr().count(output_shape_attr_name) > 0) { + auto shape_attrs = from.attr().at(output_shape_attr_name); + AddNodeAttr(output_shape_attr_name, shape_attrs, fused_op); + return true; + } + return false; +} + void AddInputShapesAttr(const RemapperContext& ctx, int node_index) { auto mutable_node = ctx.graph_view.graph()->mutable_node(node_index); @@ -3334,6 +3344,7 @@ absl::Status AddFusedContractionNode(RemapperContext* ctx, } SetFusedOpAttributes(&fused_op, {"BiasAdd"}); + MaybeCopyOutputShapesAttr(bias_add, &fused_op); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; mutation->AddNode(std::move(fused_op), &status); @@ -3439,6 +3450,7 @@ absl::Status AddFusedContractionNode( } SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()}); + MaybeCopyOutputShapesAttr(activation, &fused_op); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; @@ -3630,6 +3642,7 @@ absl::Status AddFusedContractionNode( } SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2); + MaybeCopyOutputShapesAttr(add, 
&contraction_node); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; @@ -3729,6 +3742,7 @@ absl::Status AddFusedContractionNode( } SetFusedOpAttributes(&fused_conv, {"BiasAdd", "Add", activation.op()}, 2); + MaybeCopyOutputShapesAttr(add, &fused_conv); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); absl::Status status; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 864855de1d69f6..7f131e3c23e5f4 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gradients.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/framework/batch_size_resource.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/full_type_util.h" @@ -42,6 +43,13 @@ static constexpr const char* const kGradientOp = ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); + + Status s = ctx->GetAttr("_dynamic_dim", &dynamic_dim_); + if (IsNotFound(s)) { + dynamic_dim_ = -1; + } else { + OP_REQUIRES_OK(ctx, s); + } } void ArgOp::Compute(OpKernelContext* ctx) { @@ -59,16 +67,43 @@ void ArgOp::Compute(OpKernelContext* ctx) { } }; + Tensor t; if (frame->CanConsumeArg(index_)) { - Tensor val; - frame->ConsumeArg(index_, &val); - OP_REQUIRES_OK(ctx, validate_type(val)); - ctx->set_output(0, std::move(val)); + frame->ConsumeArg(index_, &t); + OP_REQUIRES_OK(ctx, validate_type(t)); + ctx->set_output(0, std::move(t)); + val = &t; } else { OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); OP_REQUIRES_OK(ctx, validate_type(*val)); ctx->set_output(0, *val); } + if (dynamic_dim_ >= 0) { + 
BatchSizeResource* bsr = nullptr; + ScopedStepContainer* step_container = ctx->step_container(); + + OP_REQUIRES_OK(ctx, step_container->LookupOrCreate( + ctx->resource_manager(), BatchSizeResourceName, &bsr, + [](BatchSizeResource** ret) -> Status { + *ret = new BatchSizeResource(); + return OkStatus(); + })); + + const int64_t batch_size = val->dim_size(dynamic_dim_); + VLOG(1) << "Found batch_size in dimension #" << dynamic_dim_; + if (bsr->GetBatchSize() == 0) { + bsr->SetBatchSize(batch_size); + VLOG(1) << "Set batch_size from 0 to " << batch_size + << ". step_id: " << ctx->step_id(); + } else if (bsr->GetBatchSize() != batch_size) { + VLOG(1) << "Warning: Set batch_size from " << bsr->GetBatchSize() + << ". step_id: " << ctx->step_id(); + bsr->SetBatchSize(batch_size); + } else { + VLOG(1) << "batch_size already set to " << batch_size; + } + bsr->Unref(); + } } RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { diff --git a/tensorflow/core/kernels/function_ops.h b/tensorflow/core/kernels/function_ops.h index 552e1e6c515e3b..fed05ead2007ad 100644 --- a/tensorflow/core/kernels/function_ops.h +++ b/tensorflow/core/kernels/function_ops.h @@ -38,6 +38,7 @@ class ArgOp : public OpKernel { private: int index_; DataType dtype_; + int dynamic_dim_; ArgOp(const ArgOp&) = delete; void operator=(const ArgOp&) = delete; @@ -54,6 +55,7 @@ class RetvalOp : public OpKernel { private: int index_; DataType dtype_; + int dynamic_dim_; RetvalOp(const RetvalOp&) = delete; void operator=(const RetvalOp&) = delete; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index 3b50099fb9997c..a7dee8ab1d3fc8 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -400,7 +400,14 @@ std::vector PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero( for (size_t i = 0; i < shapes.size(); ++i) { const PartialTensorShape& partial = partial_shapes[i]; TensorShape& shape 
= shapes[i]; - for (int64_t s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s); + for (int d = 0; d < partial.dims(); ++d) { + shape.AddDim(partial.dim_size(d) < 0 ? 0 : partial.dim_size(d)); + xla::DynExpr* expr = partial.get_expression(d); + if (expr != nullptr && expr->is_constant() && expr->get_val() < 0) { + expr = xla::DynExpr::zero; + } + shape.AddExpression(expr); + } } return shapes; } diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 4523126f9a71bd..8dc2b49e87e09a 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -309,6 +309,8 @@ class StridedSliceAssignOp : public OpKernel { bool is_simple_slice = true; absl::InlinedVector begin; absl::InlinedVector end; + absl::InlinedVector begin_expr; + absl::InlinedVector end_expr; absl::InlinedVector strides; Tensor* old_lhs = nullptr; @@ -353,7 +355,7 @@ class StridedSliceAssignOp : public OpKernel { old_lhs->shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape, &is_identity, &is_simple_slice, &slice_dim0, - &begin, &end, &strides, &shape_spec)); + &begin, &end, &strides, &begin_expr, &end_expr, &shape_spec)); if (processing_shape.num_elements() > 0) { const Tensor& input = context->input(4); diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index 93c5a7e9818ae2..29c5d56efec504 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -22,6 +22,8 @@ limitations under the License. 
#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/lib/core/status.h" +#include "xla/shape.h" + namespace tensorflow { namespace { @@ -55,6 +57,8 @@ struct StridedSliceDenseSpec { bool end_valid; absl::InlinedVector& begin; absl::InlinedVector& end; + absl::InlinedVector& begin_expr; + absl::InlinedVector& end_expr; absl::InlinedVector& strides; // This vector helps construct the final shape of the slice. // The final tensor is reduced in rank whenever a single index e.g. foo[3] @@ -97,6 +101,8 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, // to remove any ellipsis dense->begin.resize(dense->dims); dense->end.resize(dense->dims); + dense->begin_expr.resize(dense->dims); + dense->end_expr.resize(dense->dims); dense->strides.resize(dense->dims); dense->input_shape_gather_indices_sparse.resize(dense->dims); // What indices to get the final shape from. @@ -127,6 +133,8 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, for (; full_index < next_index; full_index++) { // new_axis' aren't real axis so you have to skip dense->begin[full_index] = dense->end[full_index] = 0; + dense->begin_expr[full_index] = dense->end_expr[full_index] = + xla::DynExpr::zero; dense->strides[full_index] = 1; dense->begin_mask |= (1 << full_index); dense->end_mask |= (1 << full_index); @@ -150,9 +158,13 @@ static absl::Status BuildDenseSpec(const StridedSliceSparseSpec& sparse, // Gather slicing spec into appropriate index if (begin_flat != nullptr) { dense->begin[full_index] = internal::SubtleMustCopy(begin_flat[i]); + dense->begin_expr[full_index] = + xla::DynExpr::_(dense->begin[full_index]); } if (end_flat != nullptr) { dense->end[full_index] = internal::SubtleMustCopy(end_flat[i]); + dense->end_expr[full_index] = + xla::DynExpr::_(dense->end[full_index]); } dense->strides[full_index] = internal::SubtleMustCopy(strides_flat[i]); @@ -193,10 +205,29 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* 
begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr, + absl::InlinedVector* end_expr, StridedSliceShapeSpec* shape_spec) { + absl::InlinedVector b; + absl::InlinedVector e; + + // HACK + if (begin_expr == nullptr) { + for (int64_t i : *begin) { + b.push_back(xla::DynExpr::_(i)); + } + begin_expr = &b; + } + if (end_expr == nullptr) { + for (int64_t i : *end) { + e.push_back(xla::DynExpr::_(i)); + } + end_expr = &e; + } + if (input_shape.unknown_rank()) { // Note: If the rank is unknown, "input_shape.dims()" is -1. - return errors::InvalidArgument("Unexpected input_shape with unknown rank"); + return errors::InvalidArgument("Unexpected input_shape with unknown rank"); } const bool begin_is_wrong = @@ -271,6 +302,7 @@ absl::Status ValidateStridedSliceOp( // we need to produce the missing begin_mask for the first two // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 // we achieve begin_mask=6, end_mask=7 + StridedSliceDenseSpec dense_spec = {input_shape.dims(), 0 /* begin_mask */, 0 /* end_mask */, @@ -278,6 +310,8 @@ absl::Status ValidateStridedSliceOp( false /* end_valid */, *begin, *end, + *begin_expr, + *end_expr, *strides}; if (strides_tensor.dtype() == DT_INT32) { @@ -301,12 +335,19 @@ absl::Status ValidateStridedSliceOp( int64_t& end_i = (*end)[i]; int64_t& stride_i = (*strides)[i]; int64_t dim_i = input_shape.dim_size(i); + auto dim_exprs = input_shape.get_expressions(); + + xla::DynExpr* dim_i_expr = + i < dim_exprs.size() ? dim_exprs[i] : xla::DynExpr::_(dim_i); + if (stride_i == 0) { return errors::InvalidArgument("strides[", i, "] must be non-zero"); } bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i)); if (dim_i == -1) { processing_shape->AddDim(shrink_i ? 1 : -1); + processing_shape->AddExpression(shrink_i ? xla::DynExpr::_(1) + : xla::DynExpr::_(-1)); continue; } @@ -315,15 +356,36 @@ absl::Status ValidateStridedSliceOp( const std::array valid_range = { {stride_i > 0 ? 0 : -1, stride_i > 0 ?
dim_i : dim_i - 1}}; + const std::array valid_range_expr = { + {stride_i > 0 ? xla::DynExpr::zero : xla::DynExpr::_(-1), + stride_i > 0 ? dim_i_expr : (*dim_i_expr - *xla::DynExpr::one)->s()}}; + auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) { if (masks[c]) { return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; } else { int64_t x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive - return x_fwd < valid_range[0] - ? valid_range[0] - : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; + return x_fwd < valid_range[0] ? valid_range[0] + : x_fwd > valid_range[1] ? valid_range[1] + : x_fwd; + } + }; + auto canonical_expr = [stride_i, dim_i, masks, valid_range, + valid_range_expr, dim_i_expr](int64_t x, int c) { + if (masks[c]) { + return stride_i > 0 ? valid_range_expr[c] + : valid_range_expr[(c + 1) & 1]; + } else { + int64_t x_fwd = + x < 0 ? dim_i + x : x; // make negative indices positive + xla::DynExpr* x_expr = xla::DynExpr::_(x); + xla::DynExpr* x_fwd_expr = + x < 0 ? (*dim_i_expr + *x_expr) + : x_expr; // make negative indices positive + return x_fwd < valid_range[0] ? valid_range_expr[0] + : x_fwd > valid_range[1] ? valid_range_expr[1] + : x_fwd_expr; } }; if (shrink_i && stride_i <= 0) { @@ -341,15 +403,30 @@ absl::Status ValidateStridedSliceOp( // and canonical puts these to n-1 and 0, which implies a degenerate // interval. Fortunately, it is now safe to re-create end as begin+1. int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; + xla::DynExpr* x_fwd_expr = begin_i < 0 + ? 
(*dim_i_expr + *(*begin_expr)[i])->s() + : (*begin_expr)[i]; begin_i = x_fwd; end_i = begin_i + 1; + + (*begin_expr)[i] = x_fwd_expr; + (*end_expr)[i] = (*(*begin_expr)[i] + *xla::DynExpr::one)->s(); + if (x_fwd < 0 || x_fwd >= dim_i) { return errors::InvalidArgument( "slice index ", begin_i, " of dimension ", i, " out of bounds."); } } else { - begin_i = canonical(begin_i, 0); - end_i = canonical(end_i, 1); + const int64_t begin_raw = begin_i; + const int64_t end_raw = end_i; + begin_i = canonical(begin_raw, 0); + end_i = canonical(end_raw, 1); + if (begin_expr) { + (*begin_expr)[i] = canonical_expr(begin_raw, 0)->s(); + } + if (end_expr) { + (*end_expr)[i] = canonical_expr(end_raw, 1)->s(); + } } // Update optimization values bool take_all_in_dimension = @@ -362,14 +439,17 @@ absl::Status ValidateStridedSliceOp( } // Compute the processing shape (the intermediate Eigen will produce) int64_t interval_length; + xla::DynExpr* interval_length_expr; bool known_interval = false; if (dense_spec.begin_valid && dense_spec.end_valid) { interval_length = end_i - begin_i; + interval_length_expr = (*(*end_expr)[i] - *(*begin_expr)[i])->s(); known_interval = true; } else if (shrink_i) { // The dimension is still known as 1 for the processing_shape, but will be // discarded for the final shape. 
interval_length = 1; + interval_length_expr = xla::DynExpr::one; known_interval = true; } else if (begin_and_end_masked) { // Even if we don't have values for begin or end, we do know that this @@ -378,25 +458,34 @@ absl::Status ValidateStridedSliceOp( if (dim_i >= 0) { if (stride_i < 0) { interval_length = -dim_i; + interval_length_expr = (-1 * (*dim_i_expr))->s(); } else { interval_length = dim_i; + interval_length_expr = dim_i_expr; } known_interval = true; } } if (known_interval) { int64_t size_i; + xla::DynExpr* size_i_expr; // Hold zero if the interval is degenerate, otherwise account for // remainder if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) { size_i = 0; + size_i_expr = xla::DynExpr::zero; } else { size_i = interval_length / stride_i + (interval_length % stride_i != 0 ? 1 : 0); + size_i_expr = *(*interval_length_expr / stride_i) + + *(interval_length % stride_i != 0 ? xla::DynExpr::one + : xla::DynExpr::zero); } processing_shape->AddDim(size_i); + processing_shape->AddExpression(size_i_expr->s()); } else { processing_shape->AddDim(-1); + processing_shape->AddExpression(xla::DynExpr::_(-1)); } } @@ -425,12 +514,15 @@ absl::Status ValidateStridedSliceOp( dense_spec.final_shape_gather_indices_sparse[dense_dim]; if (gather_index >= 0) { final_shape->AddDim(processing_shape->dim_size(gather_index)); + final_shape->AddExpression( + processing_shape->get_expression(gather_index)); if (shape_spec != nullptr) { shape_spec->output_to_sparse_mapping.push_back(sparse_index); shape_spec->output_to_processing_mapping.push_back(gather_index); } } else if (gather_index == kNewAxis) { final_shape->AddDim(1); + final_shape->AddExpression(xla::DynExpr::one); if (shape_spec != nullptr) { shape_spec->output_to_sparse_mapping.push_back(-1); shape_spec->output_to_processing_mapping.push_back(-1); @@ -451,6 +543,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + 
absl::InlinedVector* begin_expr, + absl::InlinedVector* end_expr, StridedSliceShapeSpec* shape_spec) { // Validate with PartialTensorShape output PartialTensorShape partial_processing_shape, partial_final_shape; @@ -458,7 +552,8 @@ absl::Status ValidateStridedSliceOp( begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask, &partial_processing_shape, &partial_final_shape, is_identity, - is_simple_slice, slice_dim0, begin, end, strides, shape_spec)); + is_simple_slice, slice_dim0, begin, end, strides, begin_expr, end_expr, + shape_spec)); // Verify that the output shapes are fully known if (!partial_processing_shape.AsTensorShape(processing_shape) || diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h index 356b77a2a0f5b6..3e55d9bb384481 100644 --- a/tensorflow/core/util/strided_slice_op.h +++ b/tensorflow/core/util/strided_slice_op.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "xla/shape.h" namespace tensorflow { @@ -74,6 +75,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr = nullptr, + absl::InlinedVector* end_expr = nullptr, StridedSliceShapeSpec* shape_spec = nullptr); // Same as above, but the outputs are TensorShape, not PartialTensorShape @@ -87,6 +90,8 @@ absl::Status ValidateStridedSliceOp( absl::InlinedVector* begin, absl::InlinedVector* end, absl::InlinedVector* strides, + absl::InlinedVector* begin_expr = nullptr, + absl::InlinedVector* end_expr = nullptr, StridedSliceShapeSpec* shape_spec = nullptr); // Simple class for determining if it is possible to broadcast a tensor to a diff --git a/tensorflow/tools/toolchains/python/python_repo.bzl b/tensorflow/tools/toolchains/python/python_repo.bzl index 2af9b29d7af20b..158fd77bd8038b 100644 --- a/tensorflow/tools/toolchains/python/python_repo.bzl +++ b/tensorflow/tools/toolchains/python/python_repo.bzl @@ -21,6 +21,7 @@ TF_PYTHON_VERSION = "{}" HERMETIC_PYTHON_VERSION = "{}" WHEEL_NAME = "{}" WHEEL_COLLAB = "{}" +USE_PYWRAP_RULES = "False" """ def _python_repository_impl(repository_ctx): diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index e94208f90f4537..69d78b710885cb 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -469,6 +469,7 @@ cc_library( "layout_util.h", "primitive_util.h", "shape.h", + "shape_dynexpr.h", "shape_partition.h", "shape_util.h", ], diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc index 2ce99c4b93a878..5c5188ce33b5ce 100644 --- a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc +++ 
b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.cc @@ -147,7 +147,7 @@ llvm::StructType* KernelCallFrameTy(llvm::LLVMContext& ctx) { llvm::PointerType* ptr = llvm::PointerType::getUnqual(ctx); llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); return llvm::StructType::create("XLA_CPU_KernelCallFrame", ptr, ptr, i64, - ptr); + ptr, i64); } llvm::FunctionType* KernelFunctionTy(llvm::LLVMContext& ctx) { @@ -316,6 +316,34 @@ auto KernelApiIrBuilder::EmitKernelPrototype( return EmitKernelPrototype(module, name, arguments, results); } +llvm::Value* KernelApiIrBuilder::EmitGetBatchDim(llvm::IRBuilderBase& builder, + llvm::Value* call_frame) { + llvm::LLVMContext& ctx = builder.getContext(); + llvm::Type* ptr = llvm::PointerType::get(ctx, 0); + llvm::IntegerType* i64 = llvm::IntegerType::getInt64Ty(ctx); + llvm::Value* bdim_gep = + builder.CreateStructGEP(call_frame_ty_, call_frame, 4, "bdim_gep"); + llvm::Value* bdim_value = builder.CreateLoad(i64, bdim_gep, "bdim_value"); + +#if defined(PRINT_BATCHSIZE) + // Print batch size + llvm::Function* function = builder.GetInsertBlock()->getParent(); + llvm::Module* module = function->getParent(); + llvm::FunctionType* printfType = llvm::FunctionType::get( + builder.getInt32Ty(), llvm::PointerType::get(builder.getInt8Ty(), 0), + true); + llvm::Value* funcNameStr = + builder.CreateGlobalStringPtr(function->getName()); + llvm::FunctionCallee printfFunc = + module->getOrInsertFunction("printf", printfType); + llvm::Value* formatStr = + builder.CreateGlobalStringPtr("Function: %s, Batch size is : %lld!\n"); + builder.CreateCall(printfFunc, {formatStr, funcNameStr, bdim_value}); +#endif + + return bdim_value; +} + auto KernelApiIrBuilder::EmitKernelPrototype( llvm::Module& module, absl::string_view name, absl::Span arguments, @@ -394,6 +422,8 @@ auto KernelApiIrBuilder::EmitKernelPrototype( ir_results.push_back(std::move(ir_result)); } + EmitGetBatchDim(b, call_frame); + // Return null pointer to signal
success as we do not support error handling // in the compiled host kernel. llvm::BasicBlock* return_block = @@ -503,7 +533,8 @@ llvm_ir::IrArray KernelApiIrBuilder::EmitKernelArgument( const llvm::DataLayout& data_layout = llvm_module->getDataLayout(); int64_t pointer_size = data_layout.getTypeStoreSize(builder.getPtrTy()); int64_t byte_size = ShapeUtil::ByteSizeOf(shape, pointer_size); - llvm_ir::SetDereferenceableMetadataForLoad(data, byte_size); + if (!shape.has_dynamic_expr()) + llvm_ir::SetDereferenceableMetadataForLoad(data,byte_size); // All buffers pointers passed to host kernels are expected to be invariant // over the whole program. Note the metadata is attached only to loading diff --git a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h index 58ffc5bc8594e3..08af675fcb6c37 100644 --- a/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h +++ b/third_party/xla/xla/backends/cpu/codegen/kernel_api_ir_builder.h @@ -147,9 +147,13 @@ class KernelApiIrBuilder { llvm_ir::IrArray EmitKernelArgument(llvm::IRBuilderBase& builder, llvm::Value* call_frame, int64_t index, const Shape& shape); + llvm::Function* EmitKernelFunction(llvm::Module& module, absl::string_view name); + llvm::Value* EmitGetBatchDim(llvm::IRBuilderBase& builder, + llvm::Value* call_frame); + private: llvm::LLVMContext& context_; diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel.cc b/third_party/xla/xla/backends/cpu/runtime/kernel.cc index ac1a5d7181ec52..668e10f0a0dd09 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel.cc @@ -64,7 +64,7 @@ template class Kernel::ParallelTask { public: ParallelTask(XLA_CPU_Kernel* kernel, Kernel::ThreadDim thread_dims, - absl::Span args); + size_t batch_size, absl::Span args); // Invokes a host kernel for a given task index. 
absl::Status operator()(size_t task_index) const; @@ -76,6 +76,7 @@ class Kernel::ParallelTask { XLA_CPU_Kernel* kernel_; XLA_CPU_KernelThreadDim thread_dims_; + size_t batch_size_; absl::InlinedVector args_; size_t num_tasks_; @@ -88,11 +89,13 @@ class Kernel::ParallelTask { template Kernel::ParallelTask::ParallelTask( XLA_CPU_Kernel* kernel, Kernel::ThreadDim thread_dims, + size_t batch_size, absl::Span args) : kernel_(kernel), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), args_(args.begin(), args.end()), num_tasks_(thread_dims_.x * thread_dims_.y * thread_dims_.z), + batch_size_(batch_size), stride_z_(thread_dims_.y * thread_dims_.x), stride_y_(thread_dims_.x) {} @@ -103,7 +106,7 @@ absl::Status Kernel::ParallelTask::operator()( XLA_CPU_KernelThread kernel_thread = Delinearize(task_index); XLA_CPU_KernelCallFrame call_frame = {&thread_dims_, &kernel_thread, - args_.size(), args_.data()}; + args_.size(), args_.data(), batch_size_}; XLA_CPU_KernelError* error = (*kernel_)(&call_frame); @@ -138,12 +141,12 @@ Kernel::Kernel(unsigned arity, XLA_CPU_Kernel* kernel) kernel_(function_->kernel()), arity_(arity) {} -absl::Status Kernel::Launch(const ThreadDim& thread_dims, +absl::Status Kernel::Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span buffers) const { - return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); + return Launch(thread_dims, batch_size, ConvertBuffersToKernelArgs(buffers)); } -absl::Status Kernel::Launch(const ThreadDim& thread_dims, +absl::Status Kernel::Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span args) const { XLA_CPU_KernelThreadDim kernel_thread_dims = { thread_dims.x, @@ -156,8 +159,9 @@ absl::Status Kernel::Launch(const ThreadDim& thread_dims, for (uint64_t x = 0; x < thread_dims.x; ++x) { XLA_CPU_KernelThread kernel_thread = {x, y, z}; - XLA_CPU_KernelCallFrame call_frame = { - &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; + XLA_CPU_KernelCallFrame call_frame = 
{&kernel_thread_dims, + &kernel_thread, args.size(), + args.data(), batch_size}; XLA_CPU_KernelError* error = (*kernel_)(&call_frame); @@ -172,20 +176,23 @@ absl::Status Kernel::Launch(const ThreadDim& thread_dims, } tsl::AsyncValueRef Kernel::Launch( - const ThreadDim& thread_dims, absl::Span buffers, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span buffers, const Eigen::ThreadPoolDevice* device) const { - return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), device); + return Launch(thread_dims, batch_size, ConvertBuffersToKernelArgs(buffers), + device); } tsl::AsyncValueRef Kernel::Launch( - const ThreadDim& thread_dims, absl::Span args, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span args, const Eigen::ThreadPoolDevice* device) const { size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok // Short-circuit launch with a single task and run it in the caller thread. if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, args); + absl::Status launched = Launch(thread_dims, batch_size, args); return ABSL_PREDICT_TRUE(launched.ok()) ? 
OkLaunchEvent() : tsl::MakeErrorAsyncValueRef(std::move(launched)); @@ -197,11 +204,13 @@ tsl::AsyncValueRef Kernel::Launch( std::numeric_limits::max()); if (ABSL_PREDICT_TRUE(thread_dims.y == 1 && thread_dims.z == 1)) { - return Worker::Parallelize(device, num_workers, num_tasks, - ParallelTask(kernel_, thread_dims, args)); + return Worker::Parallelize( + device, num_workers, num_tasks, + ParallelTask(kernel_, thread_dims, batch_size, args)); } else { - return Worker::Parallelize(device, num_workers, num_tasks, - ParallelTask(kernel_, thread_dims, args)); + return Worker::Parallelize( + device, num_workers, num_tasks, + ParallelTask(kernel_, thread_dims, batch_size, args)); } } diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel.h b/third_party/xla/xla/backends/cpu/runtime/kernel.h index 98010905322f67..0d0465af154464 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel.h @@ -73,13 +73,13 @@ class Kernel { // Calls the kernel once in the caller thread for a thread dim (0,0,0). // This is a fast path for small host kernels that have just one thread. - absl::Status CallOnce(absl::Span args) const; + absl::Status CallOnce(absl::Span args, size_t batch_size) const; // Launches the kernel on the current thread by iterating over all threads in // `thread_dims` and calling the kernel function. - absl::Status Launch(const ThreadDim& thread_dims, + absl::Status Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span buffers) const; - absl::Status Launch(const ThreadDim& thread_dims, + absl::Status Launch(const ThreadDim& thread_dims, size_t batch_size, absl::Span args) const; // Launches the kernel by iterating over all threads in `thread_dims` and @@ -89,10 +89,12 @@ class Kernel { // Async value returned in constructed state and the caller can access it to // get the number of tasks that are expected to be completed. 
tsl::AsyncValueRef Launch( - const ThreadDim& thread_dims, absl::Span buffers, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span buffers, const Eigen::ThreadPoolDevice* device) const; tsl::AsyncValueRef Launch( - const ThreadDim& thread_dims, absl::Span args, + const ThreadDim& thread_dims, size_t batch_size, + absl::Span args, const Eigen::ThreadPoolDevice* device) const; // For host platform, we assume that a core is a thread, and we can run at @@ -123,12 +125,13 @@ class Kernel { }; inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::Status Kernel::CallOnce( - absl::Span args) const { + absl::Span args, + size_t batch_size) const { constexpr XLA_CPU_KernelThreadDim kernel_thread_dims = {1, 1, 1}; constexpr XLA_CPU_KernelThread kernel_thread = {1, 1, 1}; XLA_CPU_KernelCallFrame call_frame = {&kernel_thread_dims, &kernel_thread, - args.size(), args.data()}; + args.size(), args.data(), batch_size}; XLA_CPU_KernelError* error = (*kernel_)(&call_frame); diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h b/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h index cbe0568506385d..cf8016759ed217 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_c_api.h @@ -72,6 +72,7 @@ typedef struct XLA_CPU_KernelCallFrame { size_t num_args; const XLA_CPU_KernelArg* args; + size_t batch_size; } XLA_CPU_KernelCallFrame; // Error reporting for host kernels. NULL means success. diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc index 90c15a09bd677d..7a1c0ec70c449b 100644 --- a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc @@ -223,7 +223,7 @@ KernelThunk::ExecuteInternal( // Use a fast path if kernel called just once. 
if (ABSL_PREDICT_TRUE(call_once_)) { - TF_RETURN_IF_ERROR(kernel->CallOnce(kernel_args)); + TF_RETURN_IF_ERROR(kernel->CallOnce(kernel_args, params.batch_size)); return OkExecuteEvent(); } @@ -231,10 +231,12 @@ KernelThunk::ExecuteInternal( // by scheduling tasks into it. HostKernel launch completion will // automatically signal KernelThunk execute completion. if (ABSL_PREDICT_TRUE(params.intra_op_threadpool)) { - return kernel->Launch(thread_dim_, kernel_args, params.intra_op_threadpool); + return kernel->Launch(thread_dim_, params.batch_size, kernel_args, + params.intra_op_threadpool); } - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); + TF_RETURN_IF_ERROR( + kernel->Launch(thread_dim_, params.batch_size, kernel_args)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h index d4a56ae88fa55b..5b7042ccb2803f 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h @@ -259,6 +259,7 @@ class Thunk { TaskRunner* task_runner = nullptr; CollectiveExecuteParams* collective_params = nullptr; CustomCallExecuteParams* custom_call_params = nullptr; + int64_t batch_size = 0; ExecuteSession session = ExecuteSession(ExecuteSession::kMaxWorkers, ExecuteSession::kSplitThreshold); }; diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 33fa90f7e35e9e..961273151ed5e0 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -103,6 +103,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_use_fusion_emitters(true); opts.set_xla_cpu_use_thunk_runtime(true); opts.set_xla_cpu_use_xnnpack(false); + opts.set_xla_compile_batch_sizes(""); opts.set_xla_cpu_experimental_xnn_graph_fusion_mode( DebugOptions::XNN_GRAPH_FUSION_MODE_DISABLED); opts.set_xla_cpu_parallel_codegen_split_count(32); @@ -1003,6 +1004,15 
@@ void MakeDebugOptionsFlags(std::vector* flag_list, "`XNN_GRAPH_FUSION_MODE_DISABLED` - default value, " "`XNN_GRAPH_FUSION_MODE_GREEDY` - greedy extraction of " "XNNPACK-compatible subgraphs starting from root instructions.")); + flag_list->push_back(tsl::Flag( + "xla_compile_batch_sizes", + string_setter_for( + &DebugOptions::set_xla_compile_batch_sizes), + debug_options->xla_compile_batch_sizes(), + "Comma-separated list of batch sizes to use for compilation, " + "use single value or start:end:step format. " + "e.g. 32, 64, 128, 10:100:10, " + "empty to use the nearest power of two.")); flag_list->push_back(tsl::Flag( "xla_cpu_parallel_codegen_split_count", int32_setter_for(&DebugOptions::set_xla_cpu_parallel_codegen_split_count), diff --git a/third_party/xla/xla/executable_run_options.h b/third_party/xla/xla/executable_run_options.h index b377e91670efb9..211932ffcbba98 100644 --- a/third_party/xla/xla/executable_run_options.h +++ b/third_party/xla/xla/executable_run_options.h @@ -195,6 +195,13 @@ class ExecutableRunOptions { return *this; } + ExecutableRunOptions& set_batch_size(int64_t batch_size) { + batch_size_ = batch_size; + return *this; + } + + int64_t batch_size() const { return batch_size_; } + int32_t launch_id() const { return launch_id_; } ExecutableRunOptions& set_run_id(RunId id); @@ -265,6 +272,7 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; int32_t launch_id_ = 0; + int64_t batch_size_ = 0; stream_executor::Stream* device_to_host_stream_ = nullptr; stream_executor::Stream* host_to_device_stream_ = nullptr; ThenExecuteFunction* then_execute_function_ = nullptr; diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc index 5b0f223302f4ed..9cb5fa5eefa2b8 100644 --- a/third_party/xla/xla/hlo/builder/lib/approx_topk.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc @@ -131,10 +131,10 @@ XlaOp AggregateToTopKBuilder(XlaBuilder* 
builder, reduction_computation, {reduction_dim}); Shape op_shape = operands_shapes[0]; op_shape.set_dimensions(reduction_dim, 1); - auto top1_vals = - Reshape(GetTupleElement(val_args, 0), op_shape.dimensions()); - auto top1_args = - Reshape(GetTupleElement(val_args, 1), op_shape.dimensions()); + auto top1_vals = Reshape(GetTupleElement(val_args, 0), + op_shape.dimensions(), op_shape.expressions()); + auto top1_args = Reshape(GetTupleElement(val_args, 1), + op_shape.dimensions(), op_shape.expressions()); return Tuple(builder, {top1_vals, top1_args}); } diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc index 84908846dae705..7bb8694351ce92 100644 --- a/third_party/xla/xla/hlo/builder/lib/arithmetic.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc @@ -154,8 +154,9 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { int64_t dimension_size = input_shape.dimensions(axis); auto index_type = dimension_size <= INT32_MAX ? S32 : output_type; XlaOp index_init_value = Zero(builder, index_type); - auto iota_shape = - ShapeUtil::MakeShape(index_type, input_shape.dimensions()); + auto iota_shape = ShapeUtil::MakeShape(index_type, input_shape.dimensions(), + input_shape.expressions()); + XlaOp iota = Iota(builder, iota_shape, axis); XlaComputation reducer = CreateMinMaxComputation( diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.cc b/third_party/xla/xla/hlo/builder/lib/broadcast.cc index 1baec363b53a35..938b9e1c4df118 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.cc +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.cc @@ -31,8 +31,9 @@ limitations under the License. 
namespace xla { -absl::StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims) { +absl::StatusOr BroadcastTo( + XlaOp input, absl::Span output_dims, + absl::Span output_exprs) { XlaBuilder* builder = input.builder(); TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); absl::Span input_dims = input_shape.dimensions(); @@ -79,15 +80,37 @@ absl::StatusOr BroadcastTo(XlaOp input, } TF_RET_CHECK(input_it == input_dims.rend()); + absl::Span input_exprs = input_shape.expressions(); + std::vector broadcast_exprs; + auto input_et = input_exprs.rbegin(); + for (auto output_et = output_exprs.rbegin(); output_et != output_exprs.rend(); + ++output_et) { + if (input_et != input_exprs.rend()) { + if (*(*output_et) == *(*input_et) || + (*input_et)->is_constant() && (*input_et)->get_val() == 1) { + broadcast_exprs.push_back(*output_et); + } else if (!(*(*output_et) == *(*input_et))) { + broadcast_exprs.push_back(*input_et); + broadcast_exprs.push_back((**output_et / **input_et)->s()); + } + ++input_et; + } else { + broadcast_exprs.push_back(*output_et); + } + } + absl::c_reverse(broadcast_dims); int broadcast_shape_size = broadcast_shape.size(); for (int64_t& broadcast_dim : broadcast_dims) { broadcast_dim = broadcast_shape_size - broadcast_dim - 1; } absl::c_reverse(broadcast_shape); - XlaOp output = BroadcastInDim(input, broadcast_shape, broadcast_dims); + absl::c_reverse(broadcast_exprs); + + XlaOp output = + BroadcastInDim(input, broadcast_shape, broadcast_dims, broadcast_exprs); if (broadcast_shape != output_dims) { - output = Reshape(output, output_dims); + output = Reshape(output, output_dims, output_exprs); } return output; } diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.h b/third_party/xla/xla/hlo/builder/lib/broadcast.h index 86cf39f64ddc82..b4ccc51a40d4fd 100644 --- a/third_party/xla/xla/hlo/builder/lib/broadcast.h +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.h @@ -27,8 +27,9 @@ namespace xla { // Broadcasts 'input' up to shape 
'output_dims', using TensorFlow broadcasting // rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. -absl::StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims); +absl::StatusOr BroadcastTo( + XlaOp input, absl::Span output_dims, + absl::Span output_exprs = {}); } // namespace xla diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.cc b/third_party/xla/xla/hlo/builder/lib/matrix.cc index e9fe29ea83ee6b..3bd051eb9f582f 100644 --- a/third_party/xla/xla/hlo/builder/lib/matrix.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix.cc @@ -209,7 +209,8 @@ XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k) { } return Select(GetDiagonalMask(matrix, k), - BroadcastInDim(diag, shape.dimensions(), broadcast_dims), + BroadcastInDim(diag, shape.dimensions(), broadcast_dims, + shape.expressions()), matrix); }); } @@ -341,18 +342,22 @@ xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span config) { } TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); std::vector broadcast_sizes; + std::vector broadcast_exprs; int64_t x_dim = 0; for (auto label = config.begin(); label != config.end(); ++label) { auto first_label = absl::c_find(config, *label); if (first_label == label) { broadcast_sizes.push_back(x_shape.dimensions(x_dim)); + broadcast_exprs.push_back(x_shape.expressions(x_dim)); ++x_dim; } else { broadcast_sizes.push_back( broadcast_sizes[first_label - config.begin()]); + broadcast_exprs.push_back( + broadcast_exprs[first_label - config.begin()]); } } - x = BroadcastInDim(x, broadcast_sizes, labels->at(2)); + x = BroadcastInDim(x, broadcast_sizes, labels->at(2), broadcast_exprs); return EinsumDiagonalMask(x, config); }); } @@ -568,16 +573,20 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, int64_t dot_dim = 0; std::vector new_dims; + std::vector new_exprs; new_dims.reserve(output_rank); + new_exprs.reserve(output_rank); TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot)); for (auto d : output_config) { if 
(is_output_only(d)) { new_dims.push_back(1); + new_exprs.push_back(DynExpr::one); } else { new_dims.push_back(dot_shape.dimensions(dot_dim)); + new_exprs.push_back(dot_shape.expressions(dot_dim)); } } - return Reshape(dot, new_dims); + return Reshape(dot, new_dims, new_exprs); }); } diff --git a/third_party/xla/xla/hlo/builder/lib/prng.cc b/third_party/xla/xla/hlo/builder/lib/prng.cc index 661981f19c17ea..a18fa3cfe54e9c 100644 --- a/third_party/xla/xla/hlo/builder/lib/prng.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng.cc @@ -41,8 +41,9 @@ namespace xla { xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, absl::Span scalars) { std::vector vectors; - absl::c_transform(scalars, std::back_inserter(vectors), - [](xla::XlaOp x) { return xla::Reshape(x, {1}); }); + absl::c_transform(scalars, std::back_inserter(vectors), [](xla::XlaOp x) { + return xla::Reshape(x, {1}, {xla::DynExpr::one}); + }); return ConcatInDim(builder, vectors, 0); } @@ -154,7 +155,8 @@ std::pair GetThreeFryInputsAndUpdatedState( XlaBuilder* builder = initial_state.builder(); auto u64_shape = ShapeUtil::MakeShape(U64, shape.dimensions()); // initial_state is an R1, so reshape it to a scalar. 
- auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions()); + auto input_u64 = Broadcast(Reshape(initial_state, {}), shape.dimensions(), + shape.expressions()); int64_t trailing_dims_product = 1; for (int64_t i = shape.dimensions().size() - 1; i >= 0; --i) { if (shape.dimensions(i) < 2) { @@ -245,8 +247,12 @@ XlaOp CombineShapePair(absl::Span pair, original_shape.dimensions(shape_pair.split_dim); std::vector reshape_dims(original_shape.dimensions().begin(), original_shape.dimensions().end()); + std::vector reshape_exprs(original_shape.expressions().begin(), + original_shape.expressions().end()); reshape_dims[shape_pair.split_dim] = RoundUpTo(pre_split_size, 2); - result = Reshape(result, reshape_dims); + reshape_exprs[shape_pair.split_dim] = + DynExpr::_(RoundUpTo(pre_split_size, 2)); + result = Reshape(result, reshape_dims, reshape_exprs); if (reshape_dims[shape_pair.split_dim] != pre_split_size) { result = Slice(result, std::vector(original_shape.dimensions().size(), 0), @@ -453,11 +459,11 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, } XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]}, /*dimension=*/1); - numbers = Reshape(numbers, {bits_len * 4}); + numbers = Reshape(numbers, {bits_len * 4}, {}); numbers = Slice(numbers, /*start_indices=*/{0}, /*limit_indices=*/{num_elems}, /*strides=*/{1}); - return {Reshape(numbers, shape.dimensions()), new_state}; + return {Reshape(numbers, shape.dimensions(), shape.expressions()), new_state}; } // Generates an array of primitive type U16 with the given shape containing @@ -507,7 +513,7 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, numbers = Slice(numbers, /*start_indices=*/{0}, /*limit_indices=*/{num_elems}, /*strides=*/{1}); - return {Reshape(numbers, shape.dimensions()), new_state}; + return {Reshape(numbers, shape.dimensions(), shape.expressions()), new_state}; } XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, diff --git 
a/third_party/xla/xla/hlo/builder/lib/slicing.cc b/third_party/xla/xla/hlo/builder/lib/slicing.cc index ae0f6b987497a1..479d743a3fbc0c 100644 --- a/third_party/xla/xla/hlo/builder/lib/slicing.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing.cc @@ -168,13 +168,16 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { std::vector index_broadcast_dims; std::vector input_broadcast_dims; std::vector sizes; + std::vector expressions; sizes.reserve(index_shape.dimensions().size()); + expressions.reserve(index_shape.expressions().size()); for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { if (i < dim) { input_broadcast_dims.push_back(i); index_broadcast_dims.push_back(i); } else if (i == dim) { sizes.push_back(input_shape.dimensions(i)); + expressions.push_back(input_shape.expressions(i)); input_broadcast_dims.push_back(i); index_broadcast_dims.push_back(i + 1); } else { @@ -182,15 +185,18 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { index_broadcast_dims.push_back(i + 1); } sizes.push_back(index_shape.dimensions(i)); + expressions.push_back(index_shape.expressions(i)); } - auto mask = Eq( - BroadcastInDim(index, sizes, index_broadcast_dims), - Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes), - dim)); + auto mask = + Eq(BroadcastInDim(index, sizes, index_broadcast_dims, expressions), + Iota(builder, + ShapeUtil::MakeShape(index_shape.element_type(), sizes, + expressions), + dim)); auto masked_input = Select( - mask, BroadcastInDim(input, sizes, input_broadcast_dims), - Zeros(builder, - ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + mask, BroadcastInDim(input, sizes, input_broadcast_dims, expressions), + Zeros(builder, ShapeUtil::MakeShape(input_shape.element_type(), sizes, + expressions))); return Reduce(masked_input, Zero(builder, input_shape.element_type()), CreateScalarIdentityWithZeroComputation( input_shape.element_type(), builder), @@ -203,7 +209,8 @@ XlaOp 
TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) { to_concat.reserve(input_shape.dimensions().size()); for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (i == dim) { - to_concat.push_back(Reshape(index, index_shape.dimensions())); + to_concat.push_back(Reshape(index, index_shape.dimensions(), + index_shape.expressions())); } else { to_concat.push_back(Iota(builder, index_shape, i)); } @@ -229,27 +236,33 @@ XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); std::vector index_broadcast_dims; std::vector sizes; + std::vector expressions; const auto rank = index_shape.dimensions().size(); sizes.reserve(rank + 1); + expressions.reserve(rank + 1); for (int64_t i = 0; i < index_shape.dimensions().size(); ++i) { if (i < dim) { index_broadcast_dims.push_back(i); } else { if (i == dim) { sizes.push_back(input_shape.dimensions(i)); + expressions.push_back(input_shape.expressions(i)); } index_broadcast_dims.push_back(i + 1); } sizes.push_back(index_shape.dimensions(i)); + expressions.push_back(index_shape.expressions(i)); } auto mask = - Eq(BroadcastInDim(index, sizes, index_broadcast_dims), + Eq(BroadcastInDim(index, sizes, index_broadcast_dims, expressions), Iota(builder, - ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim)); - auto masked_src = - Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims), - Zeros(builder, - ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + ShapeUtil::MakeShape(index_shape.element_type(), sizes, + expressions), + dim)); + auto masked_src = Select( + mask, BroadcastInDim(src, sizes, index_broadcast_dims, expressions), + Zeros(builder, ShapeUtil::MakeShape(input_shape.element_type(), sizes, + expressions))); return combiner( input, @@ -287,7 +300,8 @@ XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) { 
to_concat.push_back(Iota(builder, iota_shape, batch_dim)); } - to_concat.push_back(Reshape(index, index_shape.dimensions())); + to_concat.push_back( + Reshape(index, index_shape.dimensions(), index_shape.expressions())); index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim()); } for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { diff --git a/third_party/xla/xla/hlo/builder/lib/svd.cc b/third_party/xla/xla/hlo/builder/lib/svd.cc index 561c107ba8085a..9254085e69a85e 100644 --- a/third_party/xla/xla/hlo/builder/lib/svd.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd.cc @@ -152,9 +152,9 @@ absl::StatusOr HouseRow( auto beta = Div(ScalarLike(v_0j, 2.0), (Square(Div(sigma, v_0j, broadcast_dims)) + one)); - v = Select( - BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v, - v / v_0j); + v = Select(BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), + broadcast_dims, x_shape.expressions()), + v, v / v_0j); v = Select(Eq(idx, j), zeros + one, v); beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps), @@ -219,9 +219,9 @@ absl::StatusOr HouseCol( auto beta = Div(ScalarLike(v_0i, 2.0), (Square(Div(sigma, v_0i, broadcast_dims)) + one)); - v = Select( - BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v, - v / v_0i); + v = Select(BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), + broadcast_dims, x_shape.expressions()), + v, v / v_0i); v = Select(Eq(idx, i), zeros + one, v); beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps), @@ -582,11 +582,11 @@ absl::StatusOr ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag); std::vector broadcasted_dims(num_dims - 1); std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); - auto broadcast_to_rows = - BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + auto broadcast_to_rows = BroadcastInDim( + diag, shape.dimensions(), broadcasted_dims, shape.expressions()); 
broadcasted_dims.back() = num_dims - 1; - auto broadcast_to_columns = - BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + auto broadcast_to_columns = BroadcastInDim( + diag, shape.dimensions(), broadcasted_dims, shape.expressions()); // Compute tolerance = w_{i,i} * w_{j,j} * epsilon^2 // Use at least F32 precision to avoid precision issues with small denormal. XlaOp tolerance; @@ -745,6 +745,7 @@ absl::StatusOr SortBySingularValuesAndPostProcessing( TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d)); const int64_t num_dims = shape.dimensions().size(); auto dimensions = shape.dimensions(); + auto expressions = shape.expressions(); const int64_t m = ShapeUtil::GetDimension(shape, -2); const int64_t n = ShapeUtil::GetDimension(shape, -1); @@ -763,7 +764,7 @@ absl::StatusOr SortBySingularValuesAndPostProcessing( d = Select(Ge(d, zeros), d, -d); result.v = Mul(result.v, sign, broadcast_dims); - d = BroadcastInDim(d, dimensions, broadcast_dims); + d = BroadcastInDim(d, dimensions, broadcast_dims, expressions); // As m >= n, only first n column vectors need to be permuted, and the rest of // m - n vectors are appended after the sorting is done. diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 93d7782de50e03..de4a2b10a3e49c 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -1017,22 +1017,25 @@ absl::StatusOr XlaBuilder::AddBroadcastSequence( Shape broadcast_shape = ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type()); - // Do explicit broadcast for scalar. - if (ShapeUtil::IsScalar(*operand_shape)) { - return InDimBroadcast(ShapeUtil::MakeStaticShape(broadcast_shape), operand, - {}); - } + // Do explicit broadcast for scalar. 
+ if (ShapeUtil::IsScalar(*operand_shape)) { + return InDimBroadcast(ShapeUtil::MakeStaticShape(broadcast_shape), + operand, {}); + } // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; std::vector reshaped_dynamic_dimensions; + std::vector reshaped_expressions; for (int i = 0; i < operand_shape->dimensions().size(); i++) { if (operand_shape->dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape->dimensions(i)); reshaped_dynamic_dimensions.push_back( operand_shape->is_dynamic_dimension(i)); + reshaped_expressions.push_back( + operand_shape->expressions(i)); } else { TF_RET_CHECK(operand_shape->dimensions(i) == 1 && operand_shape->is_static_dimension(i)) @@ -1046,7 +1049,7 @@ absl::StatusOr XlaBuilder::AddBroadcastSequence( Shape reshaped_shape = ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions, - reshaped_dynamic_dimensions); + reshaped_dynamic_dimensions, reshaped_expressions); // Eliminate the size one dimensions. // The added reshape reduces the rank of the tensor. 
Hence we cannot directly @@ -1094,15 +1097,20 @@ absl::StatusOr BroadcastToTargetRank( return origin; } - // Update target_size with origin sizes using broadcast_dimensions + // Update target_size and target_exp with origin sizes and expressions using + // broadcast_dimensions absl::Span target_dimensions = target_shape.dimensions(); + absl::Span target_expressions = target_shape.expressions(); std::vector target_size{target_dimensions.begin(), target_dimensions.end()}; + std::vector target_exp{target_expressions.begin(), + target_expressions.end()}; for (int64_t origin_dim = 0; origin_dim < origin_rank; origin_dim++) { int64_t target_dim = broadcast_dimensions[origin_dim]; target_size[target_dim] = origin_shape.dimensions(origin_dim); + target_exp[target_dim] = origin_shape.expressions(origin_dim); } - return BroadcastInDim(origin, target_size, broadcast_dimensions); + return BroadcastInDim(origin, target_size, broadcast_dimensions, target_exp); } // Extract the `num_dims` counts of dimension sizes from the `op`. First, @@ -1120,7 +1128,7 @@ absl::StatusOr> ExtractDimensionSizesAndPadOnesToLeft( ? ConstantR1( /*builder=*/builder, /*values=*/{static_cast(op_shape->dimensions(i))}) - : Reshape(GetDimensionSize(op, i), {1})); + : Reshape(GetDimensionSize(op, i), {1}, {xla::DynExpr::one})); } return op_dims; } @@ -1142,7 +1150,8 @@ absl::StatusOr BroadcastScalarToOutputShapeWithUnbounded( ? 
ConstantR1( /*builder=*/builder, /*values=*/{static_cast(output_shape.dimensions(i))}) - : Reshape(GetDimensionSize(output, i), {1}); + : Reshape(GetDimensionSize(output, i), {1}, + {xla::DynExpr::one}); } return MhloDynamicBroadcastInDim( scalar, /*output_dimensions=*/ConcatInDim(builder, output_sizes, 0), {}, @@ -1525,12 +1534,13 @@ XlaOp XlaBuilder::Parameter( } XlaOp XlaBuilder::Broadcast(XlaOp operand, - absl::Span broadcast_sizes) { + absl::Span broadcast_sizes, + absl::Span broadcast_exprs) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN( - const Shape& shape, - ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes)); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferBroadcastShape( + *operand_shape, broadcast_sizes, broadcast_exprs)); // The client-level broadcast op just appends dimensions on the left (adds // lowest numbered dimensions). The HLO broadcast instruction is more @@ -1550,14 +1560,15 @@ XlaOp XlaBuilder::Broadcast(XlaOp operand, XlaOp XlaBuilder::BroadcastInDim( XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::Span out_dim_exp) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. 
- TF_ASSIGN_OR_RETURN(auto output_shape, - ShapeUtil::MakeValidatedShape( - operand_shape->element_type(), out_dim_size)); + TF_ASSIGN_OR_RETURN(auto output_shape, ShapeUtil::MakeValidatedShape( + operand_shape->element_type(), + out_dim_size, out_dim_exp)); TF_RET_CHECK(!output_shape.is_unbounded_dynamic()) << "BroadcastInDim output must shape be static or bounded dynamic " << ShapeUtil::HumanString(output_shape); @@ -1584,6 +1595,18 @@ XlaOp XlaBuilder::BroadcastInDim( .status()); std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); std::vector in_dim_dynamic(out_dim_size.size(), false); + std::vector in_expressions(out_dim_exp.begin(), + out_dim_exp.end()); + + // If out_dim_exp is empty just make expressions out of the static + // dimensions. + if (out_dim_exp.empty()) { + in_expressions.reserve(out_dim_size.size()); + std::transform(out_dim_size.begin(), out_dim_size.end(), + std::back_inserter(in_expressions), + [](int d) { return DynExpr::_(d); }); + } + for (int i = 0; i < broadcast_rank; i++) { in_dim_size[broadcast_dimensions[i]] = (operand_shape->is_unbounded_dynamic_dimension(i)) @@ -1591,9 +1614,12 @@ XlaOp XlaBuilder::BroadcastInDim( : operand_shape->dimensions(i); in_dim_dynamic[broadcast_dimensions[i]] = operand_shape->is_bounded_dynamic_dimension(i); + in_expressions[broadcast_dimensions[i]] = + operand_shape->expressions(i); } - const auto& in_dim_shape = ShapeUtil::MakeShape( - operand_shape->element_type(), in_dim_size, in_dim_dynamic); + const auto& in_dim_shape = + ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size, + in_dim_dynamic, in_expressions); TF_ASSIGN_OR_RETURN( XlaOp in_dim_broadcast, InDimBroadcast(in_dim_shape, operand, broadcast_dimensions)); @@ -1637,6 +1663,21 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, }); } +XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides) { + 
return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferSliceShape( + *operand_shape, start_indices, limit_indices, strides, + start_exprs, limit_exprs)); + return SliceInternal(shape, operand, start_indices, limit_indices, strides); + }); +} + absl::StatusOr XlaBuilder::SliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span limit_indices, @@ -1668,9 +1709,33 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, }); } +XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, DynExpr* start_expr, + DynExpr* limit_expr, int64_t stride, + int64_t dimno) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); + std::vector starts(shape->dimensions().size(), 0); + std::vector limits(shape->dimensions().begin(), + shape->dimensions().end()); + std::vector start_exprs(shape->dimensions().size(), + DynExpr::zero); + std::vector limit_exprs(shape->expressions().begin(), + shape->expressions().end()); + std::vector strides(shape->dimensions().size(), 1); + starts[dimno] = start_index; + limits[dimno] = limit_index; + start_exprs[dimno] = start_expr; + limit_exprs[dimno] = limit_expr; + strides[dimno] = stride; + return Slice(operand, starts, limits, start_exprs, limit_exprs, strides); + }); +} + XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes) { + absl::Span slice_sizes, + absl::Span slice_exprs) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; @@ -1679,9 +1744,9 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::c_transform(start_indices_shapes, std::back_inserter(start_indices_shape_ptrs), [](const Shape& shape) { return &shape; }); - 
TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferDynamicSliceShape( - *operand_shape, start_indices_shapes, slice_sizes)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( + *operand_shape, start_indices_shapes, + slice_sizes, slice_exprs)); return DynamicSliceInternal(shape, operand, start_indices, slice_sizes); }); } @@ -1698,6 +1763,7 @@ absl::StatusOr XlaBuilder::DynamicSliceInternal( std::vector operands = {operand}; operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); } @@ -1794,9 +1860,22 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, int64_t inferred_dimension) { return ReportErrorOrReturn([&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape shape, - ShapeInference::InferReshapeShape( - *operand_shape, dimensions, inferred_dimension)); + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape( + *operand_shape, dimensions, + inferred_dimension, {})); + return ReshapeInternal(shape, operand, inferred_dimension); + }); +} + +XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions, + int64_t inferred_dimension) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN( + const Shape shape, + ShapeInference::InferReshapeShape(*operand_shape, dimensions, + inferred_dimension, expressions)); return ReshapeInternal(shape, operand, inferred_dimension); }); } @@ -1811,7 +1890,8 @@ XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, XlaOp XlaBuilder::DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic) { + const std::vector& dims_are_dynamic, + absl::Span expressions) { return ReportErrorOrReturn([&]() -> absl::StatusOr { 
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector dim_size_shape_ptrs; @@ -1820,10 +1900,11 @@ XlaOp XlaBuilder::DynamicReshape(XlaOp operand, absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(const Shape shape, - ShapeInference::InferDynamicReshapeShape( - *operand_shape, dim_size_shape_ptrs, - new_size_bounds, dims_are_dynamic)); + TF_ASSIGN_OR_RETURN( + const Shape shape, + ShapeInference::InferDynamicReshapeShape( + *operand_shape, dim_size_shape_ptrs, new_size_bounds, + dims_are_dynamic, expressions)); TF_RETURN_IF_ERROR(first_error_); std::vector operands; operands.reserve(1 + dim_sizes.size()); @@ -1865,17 +1946,21 @@ XlaOp XlaBuilder::Collapse(XlaOp operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; + std::vector new_exprs; for (int i = 0; i < original_shape->dimensions().size(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape->dimensions(i)); + new_exprs.push_back(original_shape->expressions(i)); } else { new_sizes.back() *= original_shape->dimensions(i); + new_exprs.back() = + *(new_exprs.back()) * *(original_shape->expressions(i)); } } VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; - return Reshape(operand, new_sizes); + return Reshape(operand, new_sizes, new_exprs); }); } @@ -3912,12 +3997,14 @@ XlaOp XlaBuilder::AllToAllArray( return all_to_all; } DimensionVector sizes; + std::vector expressions; const bool is_unbounded = operand_shape->is_unbounded_dynamic(); std::vector dynamic_sizes; auto GetR1DimensionSizeOrConstant = [&](XlaOp operand, int64_t dimension) -> XlaOp { if (operand_shape->is_unbounded_dynamic_dimension(dimension)) { - return Reshape(GetDimensionSize(operand, dimension), {1}); + return Reshape(GetDimensionSize(operand, dimension), {1}, + {DynExpr::one}); } return ConstantR1( this, 
{static_cast(operand_shape->dimensions(dimension))}); @@ -3927,15 +4014,19 @@ XlaOp XlaBuilder::AllToAllArray( for (int64_t i = 0; i < operand_shape->dimensions().size(); ++i) { if (i != split_dimension) { sizes.push_back(operand_shape->dimensions(i)); + expressions.push_back(operand_shape->expressions(i)); if (is_unbounded) { dynamic_sizes.push_back(GetR1DimensionSizeOrConstant(operand, i)); } continue; } sizes.push_back(split_count); + expressions.push_back(DynExpr::_(split_count)); sizes.push_back(operand_shape->is_unbounded_dynamic_dimension(i) ? Shape::kUnboundedSize : operand_shape->dimensions(i) / split_count); + expressions.push_back( + (*operand_shape->expressions(i) / split_count)->s()); if (is_unbounded) { dynamic_sizes.push_back(r1_split_count); @@ -3955,11 +4046,11 @@ XlaOp XlaBuilder::AllToAllArray( TF_ASSIGN_OR_RETURN( const Shape shape, ShapeUtil::MakeValidatedShape(all_to_all_shape.element_type(), sizes, - dynamic_dimensions)); + dynamic_dimensions, expressions)); all_to_all = MhloDynamicReshape(all_to_all, ConcatInDim(dynamic_sizes, 0), shape); } else { - all_to_all = Reshape(all_to_all, sizes); + all_to_all = Reshape(all_to_all, sizes, expressions); } std::vector permutation; @@ -4986,16 +5077,18 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp operand, - absl::Span broadcast_sizes) { - return operand.builder()->Broadcast(operand, broadcast_sizes); +XlaOp Broadcast(const XlaOp operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs) { + return operand.builder()->Broadcast(operand, broadcast_sizes, + broadcast_exprs); } XlaOp BroadcastInDim(const XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::Span out_dim_exp) { return operand.builder()->BroadcastInDim(operand, out_dim_size, - broadcast_dimensions); + broadcast_dimensions, out_dim_exp); } XlaOp 
MhloDynamicReshape(const XlaOp operand, const XlaOp output_shape, @@ -5030,21 +5123,29 @@ XlaOp Reshape(const XlaOp operand, absl::Span dimensions) { return operand.builder()->Reshape(operand, dimensions); } +XlaOp Reshape(const XlaOp operand, absl::Span dimensions, + absl::Span expressions) { + return operand.builder()->Reshape(operand, dimensions, expressions); +} + XlaOp Reshape(const Shape& shape, XlaOp operand) { return operand.builder()->Reshape(shape, operand); } XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic) { + const std::vector& dims_are_dynamic, + absl::Span expressions) { return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds, - dims_are_dynamic); + dims_are_dynamic, expressions); } XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, + absl::Span new_exprs, int64_t inferred_dimension) { - return operand.builder()->Reshape(operand, new_sizes, inferred_dimension); + return operand.builder()->Reshape(operand, new_sizes, new_exprs, + inferred_dimension); } XlaOp Collapse(const XlaOp operand, absl::Span dimensions) { @@ -5058,15 +5159,33 @@ XlaOp Slice(const XlaOp operand, absl::Span start_indices, strides); } +XlaOp Slice(const XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides) { + return operand.builder()->Slice(operand, start_indices, limit_indices, + start_exprs, limit_exprs, strides); +} + XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno) { return operand.builder()->SliceInDim(operand, start_index, limit_index, stride, dimno); } +XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index, + DynExpr* start_expr, DynExpr* limit_expr, int64_t stride, + int64_t dimno) { + return operand.builder()->SliceInDim(operand, start_index, limit_index, + start_expr, limit_expr, 
stride, dimno); +} + XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes) { - return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes, + slice_exprs); } XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, @@ -6085,6 +6204,14 @@ XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) { return operand.builder()->GetDimensionSize(operand, dimension); } +XlaOp GetOuterBatchValue(XlaOp operand) { + XlaBuilder* builder = operand.builder(); + return CustomCall(builder, "GetOuterBatchValue", {operand}, + ShapeUtil::MakeShape(S32, {}), "", false, {}, + nullptr, CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion::API_VERSION_ORIGINAL); +} + XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64_t dimension) { return operand.builder()->SetDimensionSize(operand, val, dimension); diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h index 8f479090d86b19..0c3c41e4cc59dc 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.h +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -519,10 +519,12 @@ class XlaBuilder { virtual XlaOp ConstantLiteral(const LiteralSlice& literal); - XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); + XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs = {}); XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::Span out_dim_exp = {}); // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim // op from the XlaBuilder. 
This is only intended for export to MHLO or @@ -545,12 +547,17 @@ class XlaBuilder { XlaOp Reshape(XlaOp operand, absl::Span dimensions, int64_t inferred_dimension = -1); + XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions, + int64_t inferred_dimension = -1); + XlaOp Reshape(const Shape& shape, XlaOp operand, int64_t inferred_dimension = -1); XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions = {}); XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); @@ -560,6 +567,13 @@ class XlaBuilder { XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); + + XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides); + virtual absl::StatusOr SliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, @@ -568,8 +582,13 @@ class XlaBuilder { virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); + virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, DynExpr* start_expr, + DynExpr* limit_expr, int64_t stride, int64_t dimno); + XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs = {}); virtual absl::StatusOr DynamicSliceInternal( const Shape& shape, XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); @@ -1244,11 +1263,13 @@ class XlaBuilder { const LiteralSlice& literal); friend XlaOp Broadcast(XlaOp operand, - absl::Span broadcast_sizes); + absl::Span broadcast_sizes, + absl::Span broadcast_expressions); friend XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + 
absl::Span out_dim_exp); friend XlaOp MhloDynamicBroadcastInDim( XlaOp operand, XlaOp output_dimensions, @@ -1265,18 +1286,22 @@ class XlaBuilder { friend XlaOp Reshape(XlaOp operand, absl::Span dimensions); + friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions); + friend XlaOp Reshape(const Shape& shape, XlaOp operand); friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions); friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); - friend XlaOp ReshapeWithInferredDimension(XlaOp operand, - absl::Span new_sizes, - int64_t inferred_dimension); + friend XlaOp ReshapeWithInferredDimension( + XlaOp operand, absl::Span new_sizes, + absl::Span new_exprs, int64_t inferred_dimension); friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); @@ -1284,12 +1309,23 @@ class XlaBuilder { absl::Span limit_indices, absl::Span strides); + friend XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides); + friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); + friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, DynExpr* start_expr, + DynExpr* limit_expr, int64_t stride, int64_t dimno); + friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs); friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); @@ -1975,7 +2011,8 @@ XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); // The new dimensions index into copies of the operand, i.e. 
// // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); +XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs = {}); // This op broadcasts the `operand` to an output with the given `shape`. // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the @@ -1993,7 +2030,8 @@ XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); // {{1 , 1}, // {2 , 2}} XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::Span out_dim_exp = {}); // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim // op from the XlaBuilder. This is only intended for export to MHLO or @@ -2038,7 +2076,8 @@ XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, // dimension dimension if dims_are_dynamic[i] is true. XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions); // This is an experimental API for creating the mhlo.dynamic_reshape op from the // XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot @@ -2050,6 +2089,9 @@ XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); // dimension sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(XlaOp operand, absl::Span dimensions); +XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span expressions); + // Enqueues a Reshape op that uses an explicit target shape. XlaOp Reshape(const Shape& shape, XlaOp operand); @@ -2059,6 +2101,7 @@ XlaOp Reshape(const Shape& shape, XlaOp operand); // is a dynamic dimension in the output, it must be the inferred dimension. 
XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, + absl::Span new_exprs, int64_t inferred_dimension); // Wrapper for Reshape. @@ -2096,6 +2139,12 @@ XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); +XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span start_exprs, + absl::Span limit_exprs, + absl::Span strides); + // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand @@ -2105,6 +2154,10 @@ XlaOp Slice(XlaOp operand, absl::Span start_indices, XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, int64_t stride, int64_t dimno); +XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, + DynExpr* start_expr, DynExpr* limit_expr, + int64_t stride, int64_t dimno); + // Enqueues a slice operation onto the computation that slices the 'operand' // from dynamic start indices which are passed in 'start_indices'. // The size of the slice in each dimension is passed in 'slice_sizes', @@ -2116,7 +2169,8 @@ XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); + absl::Span slice_sizes, + absl::Span slice_exprs = {}); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -3055,6 +3109,8 @@ XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, // array shaped. XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); +XlaOp GetOuterBatchValue(XlaOp operand); + // Sets the size of the given dimension of the operand. 
The operand must be // array shaped. The result will have the same shape as the operand, but the // given dimension will be dynamic (if not already). diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc index 07cf9a8c4bce2c..11eeefa948fa1b 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc @@ -250,6 +250,7 @@ absl::StatusOr HloPassPipeline::RunPassesInternal( } TF_RETURN_IF_ERROR(status); } + if (!pass->IsPassPipeline()) { compilation_stats_->EndPass(pass_name); } diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc index 61b261248e0e2c..aa781b082c8823 100644 --- a/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc @@ -124,9 +124,9 @@ absl::Status CombineAllGathers(absl::Span to_combine, (*perm)[ag->all_gather_dimension()]); // Bitcast operand and update output shape. 
+ auto sh = ShapeUtil::PermuteDimensions(*perm, operand_shape); operands.back() = - computation.AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::PermuteDimensions(*perm, operand_shape), operand)); + computation.AddInstruction(HloInstruction::CreateBitcast(sh, operand)); output_shapes.back() = ShapeUtil::PermuteDimensions(*perm, hlo->shape()); } } diff --git a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc index ddb505801d92ca..02dd7813d033de 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc @@ -80,10 +80,13 @@ absl::StatusOr BitcastDtypesExpander::ExpandInstruction( broadcasted_input_shape.push_back(input_bit_width / output_bit_width); reshaped_input_shape.push_back(1); int64_t output_bit_width_mask = (int64_t{1} << output_bit_width) - 1; - - TF_ASSIGN_OR_RETURN(input, - BroadcastTo(Reshape(input, reshaped_input_shape), - broadcasted_input_shape)); + std::vector reshaped_input_exprs( + from_shape.expressions().begin(), from_shape.expressions().end()); + reshaped_input_exprs.push_back(DynExpr::_(1)); + TF_ASSIGN_OR_RETURN( + input, BroadcastTo( + Reshape(input, reshaped_input_shape, reshaped_input_exprs), + broadcasted_input_shape)); input = BitcastConvertType(input, input_logical_type); TF_ASSIGN_OR_RETURN(Shape input_shape, b.GetShape(input)); XlaOp iota = Iota(&b, input_shape, input_shape.dimensions().size() - 1); diff --git a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc index b5df50e1956c90..7b11a211ee3e02 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc @@ -218,9 +218,9 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64_t block_size, l = 
UpdateSliceInMinorDims(l, update, {i + k, i}); } } - return Select( - BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices), - FullLike(l, std::numeric_limits::quiet_NaN()), l); + return Select(BroadcastInDim(seen_error, a_shape.dimensions(), + error_dim_indices, a_shape.expressions()), + FullLike(l, std::numeric_limits::quiet_NaN()), l); }); } diff --git a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc index a3787d88ebbd93..a7d62392dec79a 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc @@ -80,26 +80,42 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims); int64_t lhs_contracting_size = 1; bool lhs_contracting_dynamic = false; + int64_t lhs_contracting_multiplier_accu = 1; + DynExpr* lhs_contracting_expression = DynExpr::one; int64_t lhs_non_contracting_size = 1; bool lhs_non_contracting_dynamic = false; + int64_t lhs_non_contracting_multiplier_accu = 1; + DynExpr* lhs_non_contracting_expression = DynExpr::one; std::vector batch_dim_sizes; batch_dim_sizes.reserve(num_batch_dims); std::vector batch_dynamic_dims; batch_dynamic_dims.reserve(num_batch_dims); + std::vector batch_expressions; + batch_expressions.reserve(num_batch_dims); + + bool lhs_contracting_is_static = true; + bool lhs_non_contracting_is_static = true; + for (int64_t i = 0; i < lhs_rank; ++i) { if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) { lhs_contracting_size *= lhs_shape.dimensions(i); lhs_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i); + lhs_contracting_expression = + (*lhs_contracting_expression) * (*lhs_shape.expressions(i)); } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), i)) { batch_dim_sizes.push_back(lhs_shape.dimensions(i)); 
batch_dynamic_dims.push_back(lhs_shape.is_dynamic_dimension(i)); + batch_expressions.push_back(lhs_shape.expressions(i)); } else { lhs_non_contracting_dims.push_back(i); lhs_non_contracting_size *= lhs_shape.dimensions(i); lhs_non_contracting_dynamic |= lhs_shape.is_dynamic_dimension(i); + lhs_non_contracting_expression = + (*lhs_non_contracting_expression) * (*lhs_shape.expressions(i)); } } + // The canonical form of the lhs is // [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct] // If NonContractingDimsProduct is 1, it is omitted. @@ -123,18 +139,21 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { std::vector lhs_reshape_dims = batch_dim_sizes; std::vector lhs_reshape_dynamic_dims = batch_dynamic_dims; + std::vector lhs_reshape_expressions = batch_expressions; if (lhs_non_contracting_size > 1) { lhs_reshape_dims.push_back(lhs_non_contracting_size); lhs_reshape_dynamic_dims.push_back(lhs_non_contracting_dynamic); + lhs_reshape_expressions.push_back(lhs_non_contracting_expression->s()); } lhs_reshape_dims.push_back(lhs_contracting_size); lhs_reshape_dynamic_dims.push_back(lhs_contracting_dynamic); + lhs_reshape_expressions.push_back(lhs_contracting_expression->s()); // Reshape the contracting and non-contracting dimensions together. 
+ auto sh_lhs = ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims, + lhs_reshape_dynamic_dims, + lhs_reshape_expressions); HloInstruction* reshaped_lhs = computation->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims, - lhs_reshape_dynamic_dims), - transposed_lhs), + HloInstruction::CreateReshape(sh_lhs, transposed_lhs), &transposed_lhs->metadata()); const auto& rhs_shape = original_dot->operand(1)->shape(); @@ -145,17 +164,29 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); int64_t rhs_non_contracting_size = 1; bool rhs_non_contracting_dynamic = false; + int64_t rhs_non_contracting_multiplier_accu = 1; + DynExpr* rhs_non_contracting_expression = DynExpr::one; int64_t rhs_contracting_size = 1; bool rhs_contracting_dynamic = false; + int64_t rhs_contracting_multiplier_accu = 1; + DynExpr* rhs_contracting_expression = DynExpr::one; + + bool rhs_contracting_is_static = true; + bool rhs_non_contracting_is_static = true; + for (int64_t i = 0; i < rhs_rank; ++i) { if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) { rhs_contracting_size *= rhs_shape.dimensions(i); rhs_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i); + rhs_contracting_expression = + (*rhs_contracting_expression) * (*rhs_shape.expressions(i)); } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), i)) { rhs_non_contracting_dims.push_back(i); rhs_non_contracting_size *= rhs_shape.dimensions(i); rhs_non_contracting_dynamic |= rhs_shape.is_dynamic_dimension(i); + rhs_non_contracting_expression = + (*rhs_non_contracting_expression) * (*rhs_shape.expressions(i)); } } @@ -184,27 +215,35 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { rhs_reshape_dims.push_back(rhs_contracting_size); std::vector rhs_reshape_dynamic_dims = batch_dynamic_dims; 
rhs_reshape_dynamic_dims.push_back(rhs_contracting_dynamic); + std::vector rhs_reshape_expressions = batch_expressions; + rhs_reshape_expressions.push_back(rhs_contracting_expression->s()); if (rhs_non_contracting_size > 1) { rhs_reshape_dims.push_back(rhs_non_contracting_size); rhs_reshape_dynamic_dims.push_back(rhs_non_contracting_dynamic); + rhs_reshape_expressions.push_back(rhs_non_contracting_expression->s()); } // Reshape the contracting and non-contracting dimensions together. + auto sh_rhs = ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims, + rhs_reshape_dynamic_dims, + rhs_reshape_expressions); HloInstruction* reshaped_rhs = computation->AddInstruction( HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims, - rhs_reshape_dynamic_dims), + sh_rhs, transposed_rhs), &transposed_rhs->metadata()); std::vector dot_dims = batch_dim_sizes; std::vector dot_dynamic_dims = batch_dynamic_dims; + std::vector dot_expressions = batch_expressions; if (lhs_non_contracting_size > 1) { dot_dims.push_back(lhs_non_contracting_size); dot_dynamic_dims.push_back(lhs_non_contracting_dynamic); + dot_expressions.push_back(lhs_non_contracting_expression->s()); } if (rhs_non_contracting_size > 1) { dot_dims.push_back(rhs_non_contracting_size); dot_dynamic_dims.push_back(rhs_non_contracting_dynamic); + dot_expressions.push_back(rhs_non_contracting_expression->s()); } DotDimensionNumbers dot_dnums; @@ -251,12 +290,13 @@ absl::Status CanonicalizeDot(HloDotInstruction* original_dot) { HloInstruction::CreateReshape(result_shape, meta), &meta->metadata()); sparse_meta.push_back(meta); } + auto sh_dot = + ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims, + dot_dynamic_dims, dot_expressions); HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims, - dot_dynamic_dims), - reshaped_lhs, reshaped_rhs, dot_dnums, 
original_dot->precision_config(), - sparsity, sparse_meta)); + sh_dot, reshaped_lhs, reshaped_rhs, dot_dnums, + original_dot->precision_config(), sparsity, sparse_meta)); original_dot->SetupDerivedInstruction(dot); std::unique_ptr replacement = diff --git a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc index 33752d60cae8ce..6ddf023ffd77a7 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc @@ -158,8 +158,10 @@ void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr, Shape shape = tl.builder()->GetShape(tl).value(); std::vector broadcast_dims(shape.dimensions().size() - 1); absl::c_iota(broadcast_dims, 0); - auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims); - auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims); + auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims, + shape.expressions()); + auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims, + shape.expressions()); auto s_conj = MaybeConjugate(s, true); std::tie(tl, tr, bl, br) = @@ -179,8 +181,10 @@ void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr, std::vector broadcast_dims(shape.dimensions().size() - 1); absl::c_iota(broadcast_dims, 0); broadcast_dims.back() = shape.dimensions().size() - 1; - auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims); - auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims); + auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims, + shape.expressions()); + auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims, + shape.expressions()); auto s_conj = MaybeConjugate(s, true); std::tie(tl, tr, bl, br) = @@ -365,11 +369,12 @@ absl::Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) { TF_ASSIGN_OR_RETURN(Shape w_shape, 
builder->GetShape(w)); const int64_t num_dims = v_shape.dimensions().size(); auto dimensions = v_shape.dimensions(); + auto expressions = v_shape.expressions(); std::vector broadcast_dims(num_dims - 1); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims[num_dims - 2] = num_dims - 1; - w = BroadcastInDim(w, dimensions, broadcast_dims); + w = BroadcastInDim(w, dimensions, broadcast_dims, expressions); XlaOp sort_result = Sort({w, v}, diff --git a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc index dcf329f7c6d8b9..27b654587ba01a 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc @@ -58,6 +58,15 @@ std::vector ConcatVectors(absl::Span xs, return output; } +std::vector ConcatEVectors(absl::Span xs, + absl::Span ys) { + std::vector output; + output.reserve(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), std::back_inserter(output)); + std::copy(ys.begin(), ys.end(), std::back_inserter(output)); + return output; +} + // Computes sqrt(x^2 + y^2 + ...), avoiding overflow/underflow. // e.g. 
for 3 arguments: // def norm(x, y, z): @@ -220,11 +229,15 @@ absl::StatusOr QrExpander::QrBlock( const int64_t m = ShapeUtil::GetDimension(a_shape, -2); const int64_t n = ShapeUtil::GetDimension(a_shape, -1); + DynExpr* m_exp = ShapeUtil::GetExpression(a_shape, -2); + DynExpr* n_exp = ShapeUtil::GetExpression(a_shape, -1); const int64_t num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); + std::vector batch_exprs(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + batch_exprs[i] = ShapeUtil::GetExpression(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); @@ -248,9 +261,12 @@ absl::StatusOr QrExpander::QrBlock( minor_dim + 1); std::vector shape = batch_dims; + std::vector exprs = batch_exprs; shape.push_back(1); shape.push_back(m); - auto v_broadcast = Reshape(v, shape); + exprs.push_back(DynExpr::one); + exprs.push_back(m_exp); + auto v_broadcast = Reshape(v, shape, exprs); // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @ // (np.conj(v[np.newaxis, :]) @ a[:, j+1:])) // We use masking rather than a loop-variant shape to handle the j+1: @@ -263,7 +279,8 @@ absl::StatusOr QrExpander::QrBlock( // a[j, j] = beta // a[j+1:,j] = v[j+1:] - auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto iota = + Reshape(Iota(a.builder(), S32, m), {m, 1}, {m_exp, DynExpr::one}); auto predecessor_mask = ConvertElementType(Lt(iota, j), type); auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), std::vector(batch_dims.size(), 1)); @@ -279,7 +296,8 @@ absl::StatusOr QrExpander::QrBlock( std::vector dim_ids(num_dims); std::iota(dim_ids.begin(), dim_ids.end(), 0); new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}), - /*broadcast_dimensions=*/dim_ids); + /*broadcast_dimensions=*/dim_ids, + ConcatEVectors(batch_exprs, {m_exp, n_exp})); a = Select(Eq(iota_mn, j), new_x, a); // taus[j] = tau diff --git 
a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc index d41aaf49be5d32..2a6f428b65aa82 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc @@ -66,7 +66,7 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, XlaBuilder builder("rng"); XlaOp state_param = Parameter(&builder, 0, state_shape, "state"); - XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {}); + XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {}, {}); RngOutput output; switch (algorithm) { case RandomAlgorithm::RNG_THREE_FRY: @@ -83,8 +83,8 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, RandomAlgorithm_Name(algorithm)); } - XlaOp final_state = - ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); + XlaOp final_state = ConcatInDim( + &builder, {Reshape(key_op, {1}, {DynExpr::one}), output.state}, 0); Tuple(&builder, {final_state, output.value}); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); TF_ASSIGN_OR_RETURN(HloComputation * new_computation, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 8391636e00daac..90bf284e206a77 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -2673,6 +2673,8 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { SmallVector dimSizes; SmallVector newSizeBounds; std::vector dimsAreDynamic; + std::vector dimExpressions; + for (auto i = 0; i < resultType.getRank(); ++i) { auto runtimeSizeX1 = xla::Slice(outputShape, {i}, {i + 1}, {1}); dimSizes.push_back(xla::Reshape(runtimeSizeX1, {})); @@ -2683,9 +2685,10 @@ LogicalResult 
ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return op->emitOpError() << "unbounded dynamism is not supported"; newSizeBounds.push_back(hlo::isStaticDimSize(dimSize) ? dimSize : dimBound); dimsAreDynamic.push_back(!hlo::isStaticDimSize(dimSize)); + dimExpressions.push_back(xla::DynExpr::_(-40)); // Don't know. } - value_map[op] = - xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic); + value_map[op] = xla::DynamicReshape(operand, dimSizes, newSizeBounds, + dimsAreDynamic, dimExpressions); return success(); } @@ -2696,7 +2699,8 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { return failure(); value_map[op] = - xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions()); + xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions(), + xla::TypeToShape(op.getType()).expressions()); return success(); } @@ -3004,6 +3008,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { SmallVector dimSizes; SmallVector newSizeBounds; std::vector dimsAreDynamic; + std::vector dimExpressions; for (auto i = 0; i < resultType.getRank(); ++i) { auto runtimeSizeX1 = xla::Slice(outputShape, {i}, {i + 1}, {1}); dimSizes.push_back(xla::Reshape(runtimeSizeX1, {})); @@ -3014,9 +3019,11 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return op->emitOpError() << "unbounded dynamism is not supported"; newSizeBounds.push_back(hlo::isStaticDimSize(dimSize) ? 
dimSize : dimBound); dimsAreDynamic.push_back(!hlo::isStaticDimSize(dimSize)); + dimExpressions.push_back(xla::DynExpr::_(-50)); // Don't know } value_map[op] = - xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic); + xla::DynamicReshape(operand, dimSizes, newSizeBounds, dimsAreDynamic, + dimExpressions); return success(); } @@ -4558,7 +4565,8 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { return failure(); value_map[op] = - xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions()); + xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions(), + xla::TypeToShape(op.getType()).expressions()); return success(); } diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc index 7bfbbff39956ce..9369b89d77ec61 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc @@ -150,6 +150,7 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); std::vector is_dynamic(rank, false); + std::vector expressions(rank, DynExpr::_(-60)); for (int64_t dim = 0; dim < rank; ++dim) { int64_t size = t.getDimSize(dim); if (size == ShapedType::kDynamic) { @@ -191,7 +192,8 @@ Shape TypeToShape(mlir::Type type) { return sparse_shape; } - return ShapeUtil::MakeShape(primitive_type, shape, is_dynamic); + return ShapeUtil::MakeShape(primitive_type, shape, is_dynamic, + expressions); } else if (auto tuple_type = mlir::dyn_cast(type)) { llvm::SmallVector shapes; shapes.reserve(tuple_type.size()); diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 24247e60993937..877d55221652af 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -1156,6 +1156,16 @@ cc_library( ], ) +cc_library( + name = "executable_run_options_offset", + srcs = 
["executable_run_options_offset.cc"], + hdrs = ["executable_run_options_offset.h"], + copts = tsl_copts(), + deps = [ + "//xla:executable_run_options", + ], +) + cc_library( name = "runtime_conv3d", srcs = ["runtime_conv3d.cc"], diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index fcffdafcf5e5ee..abcfae5fccda0a 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -475,7 +475,7 @@ std::unique_ptr> CreateSimplificationPipeline( } // Needs to happen after algebraic simplifier. - pipeline->AddPass(); + // pipeline->AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 158a4000d27722..a215511c5aba5e 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -163,14 +163,34 @@ static absl::StatusOr MemoryForAllocation( se::DeviceMemoryAllocator* memory_allocator, int device_ordinal) { VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { - se::DeviceMemoryBase out = arguments[allocation.parameter_number()] + se::DeviceMemoryBase param_mem = arguments[allocation.parameter_number()] .Buffer(allocation.param_shape_index()) .AsDeviceMemoryBase(); - CHECK_LE(allocation.size(), out.size()) - << "Size mismatch on param " << allocation.parameter_number() - << " at shape index " << allocation.param_shape_index().ToString(); - VLOG(3) << "allocation is a parameter"; - return MaybeOwningDeviceMemory{out}; + + const int64_t compiled_bytes = allocation.size(); + const int64_t runtime_bytes = param_mem.size(); + + if(runtime_bytes == 0){ + return MaybeOwningDeviceMemory{param_mem}; + } + + if (compiled_bytes <= runtime_bytes) { + return MaybeOwningDeviceMemory{param_mem}; + } + + 
//padded owns that device memory + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory padded, + memory_allocator->Allocate(device_ordinal, compiled_bytes)); + + void* dst = padded-> opaque(); + void* src = param_mem.opaque(); + + std::memcpy(dst, src, runtime_bytes); + //Fill rest of them with zeros. + std::memset(static_cast(dst) + runtime_bytes, 0, compiled_bytes - runtime_bytes); + return MaybeOwningDeviceMemory{std::move(padded)}; + } else if (allocation.is_constant()) { VLOG(3) << "allocation is a constant"; if (allocation.index() < constants.size()) { @@ -275,11 +295,28 @@ absl::Status CpuExecutable::ExecuteComputeFunction( return absl::OkStatus(); } +void PrintScalars(absl::Span buffers) { + for (int i = 0; i < buffers.size(); ++i) { + const se::DeviceMemoryBase& dmem = buffers[i].AsDeviceMemoryBase(); + if (dmem.opaque() && dmem.size() >= sizeof(int64_t)) { + int64_t val = 0; + std::memcpy(&val, dmem.opaque(), sizeof(val)); + std::cerr << "Buffer " << i << " scalar: " << val << std::endl; + } else { + std::cerr << "Buffer " << i << " empty or too small" << std::endl; + } + } +} + absl::Status CpuExecutable::ExecuteThunks( const ExecutableRunOptions* run_options, absl::Span buffers) { uint64_t start_ns = tsl::Env::Default()->NowNanos(); + #if defined(PRINT_BATCHSIZE) + PrintScalars(buffers); + #endif + size_t profile_counters_size = 0; int64_t* profile_counters = nullptr; @@ -318,7 +355,8 @@ absl::Status CpuExecutable::ExecuteThunks( intra_op_thread_pool, &task_runner, &collective_execute_params, - &custom_call_execute_params}; + &custom_call_execute_params, + run_options->batch_size()}; auto executed_event = thunks_->Execute(execute_params); tsl::BlockUntilReady(executed_event); @@ -339,13 +377,15 @@ absl::StatusOr CpuExecutable::CreateResultShapedBuffer( absl::Span buffers, absl::Span arguments) { se::Stream* stream = run_options->stream(); - ExecutionOutput result(/*on_device_shape=*/result_shape(), - run_options->allocator(), - 
stream->parent()->device_ordinal()); const HloInputOutputAliasConfig& input_output_alias = module().input_output_alias_config(); HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); const Shape& root_shape = root->shape(); + // Use root_shape to initialize ExecutionOuput as the batch multiplier info + // is only attached the ROOT + ExecutionOutput result(/*on_device_shape=*/root_shape, //result_shape(), + run_options->allocator(), + stream->parent()->device_ordinal()); // Move se::OwningDeviceMemory values which contain the array(s) of the result // into the respective location in ScopedShapedBuffer which is returned to the diff --git a/third_party/xla/xla/service/cpu/dot_op_emitter.cc b/third_party/xla/xla/service/cpu/dot_op_emitter.cc index e2e4f914579dbb..e201b87a4a2a6a 100644 --- a/third_party/xla/xla/service/cpu/dot_op_emitter.cc +++ b/third_party/xla/xla/service/cpu/dot_op_emitter.cc @@ -160,6 +160,27 @@ bool CanEmitTiledLlvmIrGemm( return true; } +bool HasDynamicMatmulDims(const DotInfo& dot_info) { + const Shape& lhs_shape = dot_info.lhs_shape; + const Shape& rhs_shape = dot_info.rhs_shape; + const DotDimensionNumbers& dim_nums = dot_info.dim_nums; + + DynExpr* m_expr = lhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : lhs_shape.expressions( + 1LL - dim_nums.lhs_contracting_dimensions(0)); + DynExpr* k_expr = + lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)); + DynExpr* n_expr = rhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : rhs_shape.expressions( + 1LL - dim_nums.rhs_contracting_dimensions(0)); + + return (m_expr != nullptr && m_expr->is_dynamic()) || + (k_expr != nullptr && k_expr->is_dynamic()) || + (n_expr != nullptr && n_expr->is_dynamic()); +} + // Returns dot implementation strategy for non-batch dot operations. 
DotImplementationStrategy GetNonBatchDotImplementationStrategy( const HloModuleConfig& config, const DotInfo& dot_info, @@ -173,6 +194,10 @@ DotImplementationStrategy GetNonBatchDotImplementationStrategy( dot_info.dim_nums.rhs_batch_dimensions_size() == 0) << "Dot operations must be non-batch"; + if (HasDynamicMatmulDims(dot_info)) { + return DotImplementationStrategy::kEigen; + } + // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. @@ -253,6 +278,12 @@ class DotOpEmitter { // The number of columns on the RHS. int64_t n; + DynExpr* m_expr; + + DynExpr* k_expr; + + DynExpr* n_expr; + // True if the LHS matrix is column major. bool lhs_column_major; @@ -858,16 +889,20 @@ absl::Status DotOpEmitter::EmitCallToRuntime() { if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); + std::swap(mat_mult_dims.m_expr, mat_mult_dims.n_expr); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } - b_->CreateCall(matmul_func, - {executable_run_options_value_, target_array_.GetBasePointer(), - lhs->GetBasePointer(), rhs->GetBasePointer(), - b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n), - b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs), - b_->getInt32(transpose_rhs)}); + llvm::Value* m_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); + llvm::Value* n_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); + llvm::Value* k_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); + + b_->CreateCall( + matmul_func, + {executable_run_options_value_, target_array_.GetBasePointer(), + lhs->GetBasePointer(), rhs->GetBasePointer(), m_val, n_val, k_val, + b_->getInt32(transpose_lhs), b_->getInt32(transpose_rhs)}); return absl::OkStatus(); } @@ -942,18 +977,28 @@ absl::Status DotOpEmitter::EmitCallToBatchRuntime() { if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); 
+ std::swap(mat_mult_dims.m_expr, mat_mult_dims.n_expr); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } + llvm::Value* m_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.m_expr); + llvm::Value* n_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.n_expr); + llvm::Value* k_val = xla::llvm_ir::EmitExpression(b_, mat_mult_dims.k_expr); + DynExpr* batch_size_expr = lhs_shape.expressions(0); + if (batch_size_expr == nullptr) { + batch_size_expr = DynExpr::_(lhs_shape.dimensions(0)); + } + llvm::Value* batch_size_val = + xla::llvm_ir::EmitExpression(b_, batch_size_expr); + VLOG(1) << "Batch dot emitted with runtime:" << fn_name; b_->CreateCall( matmul_func, {executable_run_options_value_, target_array_.GetBasePointer(), - lhs->GetBasePointer(), rhs->GetBasePointer(), - b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n), - b_->getInt64(mat_mult_dims.k), b_->getInt64(lhs_shape.dimensions(0)), + lhs->GetBasePointer(), rhs->GetBasePointer(), m_val, n_val, k_val, + batch_size_val, b_->getInt32(static_cast(transpose_lhs)), b_->getInt32(static_cast(transpose_rhs))}); return absl::OkStatus(); @@ -983,6 +1028,14 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { /*n=*/rhs_shape.dimensions().size() <= 1 ? 1LL : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)), + /*m_expr=*/lhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : lhs_shape.expressions(1LL - dim_nums.lhs_contracting_dimensions(0)), + /*k_expr=*/ + lhs_shape.expressions(dim_nums.lhs_contracting_dimensions(0)), + /*n_expr=*/rhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : rhs_shape.expressions(1LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || dim_nums.lhs_contracting_dimensions(0) == 1, @@ -1014,6 +1067,14 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const { /*n=*/rhs_shape.dimensions().size() <= 1 ? 
1LL : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)), + /*m_expr=*/lhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : lhs_shape.expressions(2LL - dim_nums.lhs_contracting_dimensions(0)), + /*k_expr=*/ + lhs_shape.expressions(1LL + dim_nums.lhs_contracting_dimensions(0)), + /*n_expr=*/rhs_shape.dimensions().size() <= 1 + ? DynExpr::one + : rhs_shape.expressions(2LL - dim_nums.rhs_contracting_dimensions(0)), /*lhs_column_major=*/is_column_major(lhs_shape), /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 || dim_nums.lhs_contracting_dimensions(0) == 1, @@ -1093,22 +1154,34 @@ absl::Status EmitNonBatchDotOperation( Shape DropFirstDim(const Shape& shape) { absl::Span array_shape_dims(shape.dimensions()); + absl::Span array_shape_exprs(shape.expressions()); array_shape_dims.remove_prefix(1); - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - array_shape_dims); + array_shape_exprs.remove_prefix(1); + return ShapeUtil::MakeShapeWithDescendingLayout( + shape.element_type(), array_shape_dims, array_shape_exprs); } Shape CollapseFirstNDims(const Shape& shape, int64_t n) { absl::Span input_shape_dims(shape.dimensions()); + absl::Span input_expressions(shape.expressions()); int64_t prefix_dim = std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n, 1ll, std::multiplies()); + + DynExpr* prefix_expression = std::accumulate( + input_expressions.begin(), input_expressions.begin() + n, DynExpr::one, + [](DynExpr* acc, DynExpr* v) { return (*acc) * (*v); }); + DimensionVector result_dims; + std::vector result_expressions; result_dims.push_back(prefix_dim); + result_expressions.push_back(prefix_expression->s()); std::copy(input_shape_dims.begin() + n, input_shape_dims.end(), std::back_inserter(result_dims)); - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - result_dims); + std::copy(input_expressions.begin() + n, input_expressions.end(), + std::back_inserter(result_expressions)); + return 
ShapeUtil::MakeShapeWithDescendingLayout( + shape.element_type(), result_dims, result_expressions); } llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilderBase* b, diff --git a/third_party/xla/xla/service/cpu/executable_run_options_offset.cc b/third_party/xla/xla/service/cpu/executable_run_options_offset.cc new file mode 100644 index 00000000000000..2d84ddfa2ea352 --- /dev/null +++ b/third_party/xla/xla/service/cpu/executable_run_options_offset.cc @@ -0,0 +1,28 @@ +#include "executable_run_options_offset.h" +#include "xla/executable_run_options.h" + +namespace xla::cpu { + +// Friend-injection trick to get a pointer-to-private-member for the *real* +// xla::ExecutableRunOptions::batch_size_. +template +struct Linker { + friend constexpr typename Tag::type get_offset(Tag) { return Ptr; } +}; + +struct BatchSizeTag { + using type = int64_t xla::ExecutableRunOptions::*; + friend constexpr type get_offset(BatchSizeTag); +}; + +// Instantiate template to expose &ExecutableRunOptions::batch_size_. +template struct Linker; + +size_t ExecutableRunOptionsBatchSizeOffset() { + auto ptr = get_offset(BatchSizeTag{}); + // Compute offset in bytes from null pointer. 
+ return reinterpret_cast( + &(reinterpret_cast(0)->*ptr)); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/executable_run_options_offset.h b/third_party/xla/xla/service/cpu/executable_run_options_offset.h new file mode 100644 index 00000000000000..3e702454e76ec5 --- /dev/null +++ b/third_party/xla/xla/service/cpu/executable_run_options_offset.h @@ -0,0 +1,8 @@ +#pragma once +#include + +namespace xla::cpu{ + + size_t ExecutableRunOptionsBatchSizeOffset(); + +} diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index feca6552d243f8..2925ddf8127838 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -2067,7 +2067,7 @@ absl::Status IrEmitter::HandleSlice(HloInstruction* slice) { const int64_t memcpy_elements = primitive_elements_per_logical_element * memcpy_logical_elements; - EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements, + EmitTransferElements(memcpy_dest, memcpy_source, DynExpr::_(memcpy_elements), slice->shape().element_type(), target_array, source_array); @@ -2358,6 +2358,24 @@ absl::Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { return EmitSliceToDynamic(hlo, source_arrays, target_array); } +absl::Status IrEmitter::HandleOuterBatchValue(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + + llvm_ir::IrArray out_array = GetIrArrayFor(hlo); + + llvm::Value* expr_value = + llvm_ir::EmitExpression(b(), hlo->operand(0)->shape().expressions(0)); + + auto it = emitted_value_.find(hlo); + if (it == emitted_value_.end()) { + LOG(ERROR) << "No buffer assigned for instruction " << hlo->name(); + } + llvm::Value* dest_ptr = it->second; + b()->CreateStore(expr_value, dest_ptr); + + return absl::OkStatus(); +} + absl::Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); @@ -2806,6 +2824,10 @@ absl::Status 
IrEmitter::HandleOneDnnSoftmax(HloInstruction* custom_call) { #endif // INTEL_MKL absl::Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + if (custom_call->custom_call_target() == "GetOuterBatchValue") { + return HandleOuterBatchValue(custom_call); + } + if (custom_call->custom_call_target() == "PadToStatic") { return HandlePadToStatic(custom_call); } @@ -3125,11 +3147,11 @@ absl::Status EmitFastConcatenate( // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = target_array.EmitArrayElementAddress(target_index, &b, "target_region"); - int64_t byte_offset_into_target_region = 0; + llvm::Value* byte_offset_into_target_region = b.getInt64(0); - int64_t inner_dims_product = absl::c_accumulate( - inner_dims, int64_t{1}, [&](int64_t product, int64_t inner_dim) { - return product * output_shape.dimensions(inner_dim); + DynExpr* inner_exprs_product = absl::c_accumulate( + inner_dims, DynExpr::one, [&](DynExpr* product, int64_t inner_dim) { + return *product * *output_shape.expressions(inner_dim); }); // For each operand, emit a memcpy from the operand to the target of size @@ -3142,18 +3164,24 @@ absl::Status EmitFastConcatenate( llvm::Value* copy_source_address = source_array.EmitArrayElementAddress(source_index, &b, "src_addr"); - llvm::Value* copy_target_address = - b.CreateGEP(b.getInt8Ty(), target_region_begin, - b.getInt64(byte_offset_into_target_region)); + llvm::Value* copy_target_address = b.CreateGEP( + b.getInt8Ty(), target_region_begin, byte_offset_into_target_region); + + auto cexpr = input_shape.expressions(concat_dim); + + ::xla::cpu::EmitTransferElements(copy_target_address, copy_source_address, + (*inner_exprs_product * *cexpr)->s(), + primitive_type, target_array, source_array, + module, b); - ::xla::cpu::EmitTransferElements( - copy_target_address, copy_source_address, - inner_dims_product * input_shape.dimensions(concat_dim), primitive_type, - target_array, source_array, module, 
b); + llvm::Value* concat_dim_count = xla::llvm_ir::EmitExpression( + &b, (*inner_exprs_product * *input_shape.expressions(concat_dim))->s()); - byte_offset_into_target_region += inner_dims_product * - input_shape.dimensions(concat_dim) * - primitive_type_size; + llvm::Value* concat_dim_size = + b.CreateMul(concat_dim_count, b.getInt64(primitive_type_size)); + byte_offset_into_target_region = + b.CreateAdd(byte_offset_into_target_region, concat_dim_size, + "byte_offset_into_target_region"); } if (!outer_dims.empty()) { @@ -3364,7 +3392,7 @@ llvm::Value* IrEmitter::EmitCallToFfi(HloCustomCallInstruction* custom_call, } void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, + xla::DynExpr* element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array) { @@ -3374,7 +3402,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } void EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, PrimitiveType primitive_type, + xla::DynExpr* element_count, + PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array, llvm::Module* module, llvm::IRBuilderBase& b) { @@ -3386,7 +3415,7 @@ void EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm::Type* primitive_llvm_type = llvm_ir::PrimitiveTypeToIrType(primitive_type, module->getContext()); - if (element_count == 1) { + if (element_count == DynExpr::one) { auto* load_instruction = b.CreateAlignedLoad(primitive_llvm_type, source, element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); @@ -3394,11 +3423,12 @@ void EmitTransferElements(llvm::Value* target, llvm::Value* source, b.CreateAlignedStore(load_instruction, target, element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { + auto element_count_value = 
xla::llvm_ir::EmitExpression(&b, element_count); + llvm::Value* elements_size = + b.CreateMul(element_count_value, b.getInt64(primitive_type_size)); auto* memcpy_instruction = b.CreateMemCpy( target, /*DstAlign=*/llvm::Align(element_alignment), source, - /*SrcAlign=*/llvm::Align(element_alignment), - element_count * primitive_type_size); - + /*SrcAlign=*/llvm::Align(element_alignment), elements_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. std::map merged_metadata = @@ -3911,8 +3941,10 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( if (!target_shape.IsOpaque()) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, - target_shape); + if (!target_shape.has_dynamic_expr()) { + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } } return param_address_untyped; } @@ -3958,7 +3990,10 @@ llvm::Value* IrEmitter::EmitGlobalBufferPointer( AttachInvariantLoadMetadataForLoad(tempbuf_address_base); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); - AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); + + if (!target_shape.has_dynamic_expr()) + AttachDereferenceableMetadataForLoad(tempbuf_address_base, + allocation.size()); llvm::Value* tempbuf_address_untyped = tempbuf_address_base; // Any explicit buffer pointer should point to the start of the slice. @@ -4059,9 +4094,38 @@ absl::Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* source_value = GetEmittedValueFor(&source); llvm::Value* destination_value = GetEmittedValueFor(&destination); int64_t source_size = ByteSizeOf(source.shape()); - // TODO(b/63762267): Be more aggressive about specifying alignment. 
- MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value, - /*SrcAlign=*/llvm::Align(1), source_size); + auto shape = source.shape(); + auto expressions = shape.expressions(); + bool is_dynamic = + std::any_of(expressions.begin(), expressions.end(), + [](DynExpr* e) { return e->is_dynamic(); }); + if (is_dynamic) { + llvm::LLVMContext& ctx = b()->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + int64_t dimensions_accu = 1; + DynExpr* expression_accu = DynExpr::one; + for (int i = 0; i < shape.dimensions_size(); i++) { + auto expression = shape.expressions(i); + if (expression->is_dynamic()) { + dimensions_accu *= shape.dimensions(i); + expression_accu = (*expression_accu) * (*expression); + } + } + llvm::Value* expr_value = + xla::llvm_ir::EmitExpression(b(), expression_accu->s()); + // Divide the size in bytes by the size of the dynamic dimension(s). + // TODO: make that less hacky + llvm::ConstantInt* size = + llvm::ConstantInt::get(i64Type, source_size / dimensions_accu, true); + llvm::Value* memcopy_size = + b()->CreateMul(expr_value, size, "memcopy_size"); + MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value, + /*SrcAlign=*/llvm::Align(1), memcopy_size); + } else { + // TODO(b/63762267): Be more aggressive about specifying alignment. 
+ MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value, + /*SrcAlign=*/llvm::Align(1), source_size); + } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 40f54d2f4bff97..4d4af3805e7b1d 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -332,6 +332,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, private: absl::Status HandleSliceToDynamic(HloInstruction* hlo); + absl::Status HandleOuterBatchValue(HloInstruction* hlo); absl::Status HandlePadToStatic(HloInstruction* hlo); absl::Status HandleTopK(HloInstruction* hlo) override; absl::Status HandleAllReduceSingleReplica(HloInstruction* crs); @@ -569,7 +570,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". void EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, PrimitiveType primitive_type, + xla::DynExpr* element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); @@ -859,7 +860,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Decoupled implementation of IrEmitter::EmitTransferElements. 
void EmitTransferElements(llvm::Value* target, llvm::Value* source, - int64_t element_count, PrimitiveType primitive_type, + xla::DynExpr* element_count, PrimitiveType primitive_type, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array, llvm::Module* module, llvm::IRBuilderBase& b); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index bce2108bb87572..8ce0ad8129a155 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -195,6 +195,27 @@ absl::StatusOr IrEmitter2::EmitPadHostKernel( KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } +absl::StatusOr +IrEmitter2::EmitGetOuterBatchValueHostKernel(const HloInstruction* getBatch) { + VLOG(2) << "Emit GetOuterBatchValue host kernel: " << getBatch->name(); + + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(getBatch)); + llvm_ir::IrArray operand_array = kernel_prototype.arguments[0]; + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + xla::DynExpr* expr = getBatch->operand(0)->shape().expressions(0); + llvm::IRBuilder<> b(module_->getContext()); + b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); + llvm::Value* bdim_value = llvm_ir::EmitExpression(&b, expr); + llvm_ir::IrArray::Index output_index(/*multidimensional_index=*/{}, + getBatch->shape(), b.getInt32Ty()); + llvm::Value* output_ptr = + output_array.EmitArrayElementAddress(output_index, &b); + b.CreateStore(bdim_value, output_ptr); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); +} + absl::StatusOr IrEmitter2::EmitFusionHostKernel( const HloFusionInstruction* fusion) { VLOG(2) << "Emit fusion host kernel: " << fusion->name(); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h index e720e06f37642c..2a0b98d5856332 
100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.h +++ b/third_party/xla/xla/service/cpu/ir_emitter2.h @@ -113,6 +113,9 @@ class IrEmitter2 { // Emits a host kernel for the pad instruction. absl::StatusOr EmitPadHostKernel(const HloInstruction* pad); + absl::StatusOr EmitGetOuterBatchValueHostKernel( + const HloInstruction* getBatch); + // Emits a host kernel for the given fusion instruction. absl::StatusOr EmitFusionHostKernel( const HloFusionInstruction* fusion); diff --git a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc index 2bfffd88df937e..e3af0399381b79 100644 --- a/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/cpu/parallel_loop_emitter.cc @@ -61,6 +61,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // performance with a large improvement in compile time. auto unroll_mode = (i == 0) ? llvm_ir::UnrollMode::kDefaultUnroll : llvm_ir::UnrollMode::kNoUnroll; + if (bounds_index < dynamic_loop_bounds_->size()) { // Emit dynamic loop bounds for this dimension. Dynamic loop bounds // are read from ir function dynamic loop bounds argument. @@ -69,14 +70,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::unique_ptr loop = loop_nest.AddLoop( /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, - end_index, unroll_mode); + end_index, unroll_mode, /*prevent_vectorization*/ false, + /* expression */ shape_.expressions(dimension)); array_multi_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. 
std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/absl::StrFormat("dim.%d", dimension), unroll_mode); + /*suffix=*/absl::StrFormat("dim.%d", dimension), unroll_mode, + /*prevent_vectorization*/ false, + /* expression */ shape_.expressions(dimension)); array_multi_index[dimension] = loop->GetIndVarValue(); } } diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 9079ed8a16d1a8..b6c32ecf8d4eb0 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -1086,6 +1086,8 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( return EmitTopKThunk(custom_call); } else if (custom_call_target == "SliceToDynamic") { return EmitSliceToDynamicThunk(instruction); + } else if (custom_call_target == "GetOuterBatchValue") { + return EmitGetOuterBatchValueThunk(instruction); } // Check the API version. @@ -1128,6 +1130,21 @@ absl::StatusOr ThunkEmitter::EmitSliceToDynamicThunk( /*min_alignment=*/cpu_function_runtime::MinAlign()); } +absl::StatusOr ThunkEmitter::EmitGetOuterBatchValueThunk( + const HloInstruction* instruction) { + VLOG(2) << "Handling GetOuterBatchValue for instruction: " + << instruction->ToString(); + const HloCustomCallInstruction* custom_call = + Cast(instruction); + TF_ASSIGN_OR_RETURN( + auto kernel, ir_emitter_.EmitGetOuterBatchValueHostKernel(custom_call)); + TF_ASSIGN_OR_RETURN(auto result_buffer, + GetHostKernelAllocationSlices(instruction)); + return MakeKernelThunkSequence( + instruction, result_buffer, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); +} + absl::StatusOr ThunkEmitter::EmitSliceThunk( const HloInstruction* instruction) { // TODO(ezhulenev): Consider implementing slice operations as separate diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 9cdc8ae3981680..9949fef0329371 
100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -183,6 +183,9 @@ class ThunkEmitter { absl::StatusOr EmitSliceToDynamicThunk( const HloInstruction* instruction); + absl::StatusOr EmitGetOuterBatchValueThunk( + const HloInstruction* instruction); + absl::StatusOr EmitTopKThunk( const HloCustomCallInstruction* custom_call); diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 89391a3bdd1dd5..b2f20ebcacdf0a 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -3278,51 +3278,57 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( } // We use bisection to select the input operand. - int64_t current_offset = 0; + int64_t coffset = 0; + llvm::Value* current_offset = source_index.GetConstantWithIndexType(0); // Offset for every operand. - std::vector> cases; + std::vector> cases; cases.reserve(hlo->operand_count()); for (const HloInstruction* operand : hlo->operands()) { cases.emplace_back(current_offset, operand); - current_offset += operand->shape().dimensions(concat_dim); + llvm::Value* cdim = source_index.GetConstantWithIndexType( + operand->shape().dimensions(concat_dim)); + xla::DynExpr* concat_expr = operand->shape().expressions(concat_dim); + if (concat_expr != nullptr && concat_expr->is_dynamic()) { + cdim = llvm_ir::EmitExpression(b_, concat_expr); + } + current_offset = b_->CreateAdd(current_offset, cdim, "current_offset"); + coffset += operand->shape().dimensions(concat_dim); } - CHECK_EQ(current_offset, hlo->shape().dimensions(concat_dim)); + CHECK_EQ(coffset, hlo->shape().dimensions(concat_dim)); std::function> operands)> + absl::Span> operands)> emit_tree = - [&](absl::Span> + [&](absl::Span> operands) { llvm::IRBuilder<>::InsertPointGuard guard(*b_); size_t mid = operands.size() / 2; - const std::pair& pivot = + const std::pair& pivot 
= operands[mid]; llvm::BasicBlock* block = llvm_ir::CreateBasicBlock( exit_block, - absl::StrCat("concatenate.pivot.", pivot.first, "."), b_); + absl::StrCat("concatenate.pivot."), b_); b_->SetInsertPoint(block); // If there's only one element we're done. The range is contiguous // so we can just jump to the block for it. if (operands.size() == 1) { - const std::pair& operand = + const std::pair& operand = operands.back(); int64_t operand_id = to_unique_operand_id[operand.second]; source_index_phis[operand_id]->addIncoming( - source_index.GetConstantWithIndexType(operand.first), + operand.first, b_->GetInsertBlock()); b_->CreateBr(emit_operand_blocks[operand_id]); return block; } // Take the middle element and recurse. - llvm::Constant* pivot_const = llvm::ConstantInt::get( - source_index[concat_dim]->getType(), pivot.first); llvm::Value* comp = - b_->CreateICmpULT(source_index[concat_dim], pivot_const); + b_->CreateICmpULT(source_index[concat_dim], pivot.first); llvm::BasicBlock* left_block = emit_tree(operands.subspan(0, mid)); llvm::BasicBlock* right_block = emit_tree(operands.subspan(mid)); @@ -3616,11 +3622,16 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( "in_bounds"); multi_index[i] = SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = - And(in_bounds, - ICmpSLT(multi_index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); + + int64_t shape_dim = hlo->operand(0)->shape().dimensions(i); + llvm::Value* bound = index_typed_const(shape_dim); + + xla::DynExpr* operand_expr = hlo->operand(0)->shape().expressions(i); + if (operand_expr != nullptr && operand_expr->is_dynamic()) { + bound = llvm_ir::EmitExpression(b_, operand_expr); + } + + in_bounds = And(in_bounds, ICmpSLT(multi_index[i], bound), "in_bounds"); } // if (in_bounds) { @@ -3687,9 +3698,20 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDot( return llvm::ConstantInt::get(index_type, c); }; + llvm::Value* contracted_bound = 
index_typed_const(contracted_dim_size); + + if (!hlo->operand(0) + ->shape() + .expressions(lhs_contracting_dim) + ->is_constant()) { + llvm::Value* expr_value = llvm_ir::EmitExpression( + b_, hlo->operand(0)->shape().expressions(lhs_contracting_dim)); + contracted_bound = expr_value; + } + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), index_typed_const(0), - index_typed_const(contracted_dim_size), index_typed_const(1), b_); + IrName(hlo, "inner"), index_typed_const(0), contracted_bound, + index_typed_const(1), b_); SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_); PrimitiveType primitive_type = hlo->shape().element_type(); @@ -3883,9 +3905,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); std::vector source_multi_index = target_index.multidim(); for (int64_t dim : hlo->dimensions()) { - source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType( - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + xla::DynExpr* dim_expr = hlo->shape().expressions(dim); + if (dim_expr != nullptr && dim_expr->is_dynamic()) { + llvm::Value* one = target_index.GetConstantWithIndexType(1); + llvm::Value* expr_value = llvm_ir::EmitExpression(b_, dim_expr); + source_multi_index[dim] = + Sub(Sub(expr_value, one), target_index[dim]); + } else { + source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType( + hlo->shape().dimensions(dim) - 1), + target_index[dim]); + } } llvm_ir::IrArray::Index source_index( source_multi_index, operand->shape(), target_index.GetType()); @@ -4236,11 +4266,21 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( // comparison is equivalent to the unsigned comparison // input_multi_index[i] < bound, as a negative value wraps to a large // positive value. 
+ + int64_t dim_bound = reduce_window->inputs()[0]->shape().dimensions(i); + llvm::Value* shape_bound = index_typed_const(dim_bound); + + xla::DynExpr* window_expr = + reduce_window->inputs()[0]->shape().expressions(i); + if (window_expr != nullptr && window_expr->is_dynamic()) { + llvm::Value* expr_value = llvm_ir::EmitExpression(b_, window_expr); + shape_bound = expr_value; + } + in_bounds = And(in_bounds, ICmpULT(input_multi_index[i], - index_typed_const( - reduce_window->inputs()[0]->shape().dimensions(i)))); + shape_bound)); } llvm_ir::LlvmIfData if_data = diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index b815b1ffc79627..3a33c5f89a2e36 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -655,9 +655,12 @@ absl::StatusOr CollapseFirstNDims(HloInstruction* operand, CHECK_GE(operand_shape.dimensions_size(), n); int64_t new_shape_leading_bound = 1; bool new_shape_leading_is_dynamic = false; + DynExpr* new_shape_leading_expression = DynExpr::one; for (int64_t i = 0; i < n; i++) { new_shape_leading_bound *= operand_shape.dimensions(i); new_shape_leading_is_dynamic |= operand_shape.is_dynamic_dimension(i); + new_shape_leading_expression = + (*new_shape_leading_expression) * (*operand_shape.expressions(i)); } std::vector new_shape_dims; @@ -675,8 +678,16 @@ absl::StatusOr CollapseFirstNDims(HloInstruction* operand, operand_shape.dynamic_dimensions().end(), std::back_inserter(new_shape_dynamic_dims)); - Shape output_shape = ShapeUtil::MakeShape( - operand_shape.element_type(), new_shape_dims, new_shape_dynamic_dims); + std::vector new_shape_expressions; + new_shape_expressions.reserve(operand_shape.dimensions_size() - n + 1); + new_shape_expressions.push_back(new_shape_leading_expression->s()); + auto exprs = operand_shape.expressions(); + std::copy(exprs.begin() + n, exprs.end(), + std::back_inserter(new_shape_expressions)); + + 
Shape output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims, + new_shape_dynamic_dims, new_shape_expressions); return MakeReshapeHlo(output_shape, operand); } diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index b7bea1e55704c5..b5adc212b53a17 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -2571,7 +2571,7 @@ absl::Status LayoutAssignment::PropagateComputationLayouts( *result_layout = computed_computation_layout.result_layout(); } else { TF_RET_CHECK( - Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()( + Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout().IgnoreBatch()( computed_computation_layout.result_layout().shape(), result_layout->shape())); } diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD index 599060b88eee81..2a665c2d77f3b2 100644 --- a/third_party/xla/xla/service/llvm_ir/BUILD +++ b/third_party/xla/xla/service/llvm_ir/BUILD @@ -84,6 +84,7 @@ cc_library( "//xla/service/cpu:cpu_options", "//xla/tsl/platform:byte_order", "//xla/tsl/platform:logging", + "//xla/service/cpu:executable_run_options_offset", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 139d537c88778a..a2e4c363939504 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -281,8 +281,11 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. 
for (int64_t i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { + xla::DynExpr* input_expr = input_shape.expressions(i); + bool is_dynamic = input_expr != nullptr && input_expr->is_dynamic(); llvm::Value* divisor = - GetConstantWithIndexType(input_shape.dimensions(i)); + is_dynamic ? llvm_ir::EmitExpression(builder, input_expr) + : GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { source_multidim_index[i] = GetConstantWithIndexType(0); } else if (i == common_factors[k].first) { @@ -559,8 +562,25 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, int64_t dimension = LayoutUtil::Major(shape_.layout(), i); gep_indices.push_back(actual_index[dimension]); } - return b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, - llvm_ir::AsStringRef(name)); + + // Do not make a dynamic "GEP" if only the first dimension is dynamic since + // it's always indiced with 0 (i.e. the dynamic dimension has no impact on the + // address computation). 
+ auto expressions = shape_.expressions(); + bool dynamic_first_dim = + !expressions.empty() && expressions[0]->is_dynamic() && + std::all_of(expressions.begin() + 1, expressions.end(), + [](DynExpr* e) { return e->is_constant(); }); + if (!dynamic_first_dim && shape_.has_dynamic_expr()) { + llvm::Type* element_type = + PrimitiveTypeToIrType(shape_.element_type(), b->getContext()); + return llvm_ir::createDynamicGEP( + b, base_ptr_, gep_indices, shape_.dimensions(), expressions, + element_type, llvm_ir::AsStringRef(name)); + } else { + return b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); + } } llvm::Value* IrArray::EmitLinearArrayElementAddress( diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc index 43ee77ef0e85be..562260717ba35d 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/tsl/platform/logging.h" +#include "llvm/Support/Debug.h" namespace xla { namespace llvm_ir { @@ -185,24 +186,32 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } -std::unique_ptr ForLoopNest::AddLoop(absl::string_view suffix, - llvm::Value* start_index, - llvm::Value* end_index, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + UnrollMode unroll_mode, bool prevent_vectorization, + DynExpr* expression) { return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), - unroll_mode, prevent_vectorization); + unroll_mode, prevent_vectorization, expression); } std::unique_ptr ForLoopNest::AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization, + DynExpr* expression) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } + llvm::Value* actual_end = end_index; + if (expression && expression->is_dynamic()) { + // Get batch dim and compare with end_index to use minimum value + llvm::Value* expr_value = + llvm_ir::EmitExpression(b_, expression); + actual_end = b_->CreateSelect(b_->CreateICmpULT(end_index, expr_value), + end_index, expr_value, "loop_end_min"); + } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + /*prefix=*/name_, suffix, start_index, actual_end, stride, unroll_mode, prevent_vectorization)); loop->Emit(b_); @@ -219,25 +228,31 @@ std::unique_ptr ForLoopNest::AddLoop( return loop; } -std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, - int64_t end_index, - absl::string_view suffix, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr ForLoopNest::AddLoop( + int64_t start_index, int64_t end_index, absl::string_view suffix, + UnrollMode unroll_mode, bool prevent_vectorization, + DynExpr* expression) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, GetConstantWithIndexType(start_index), - GetConstantWithIndexType(end_index), unroll_mode, - prevent_vectorization); + + llvm::Value* end = (expression && expression->is_dynamic()) + ? 
EmitExpression(b_, expression) + : GetConstantWithIndexType(end_index); + return AddLoop(suffix, GetConstantWithIndexType(start_index), end, + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64_t start_index, int64_t end_index, int64_t stride, absl::string_view suffix, UnrollMode unroll_mode, - bool prevent_vectorization) { + bool prevent_vectorization, + DynExpr* expression) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, GetConstantWithIndexType(start_index), - GetConstantWithIndexType(end_index), + + llvm::Value* end = (expression && expression->is_dynamic()) + ? EmitExpression(b_, expression) + : GetConstantWithIndexType(end_index); + return AddLoop(suffix, GetConstantWithIndexType(start_index), end, GetConstantWithIndexType(stride), unroll_mode, prevent_vectorization); } @@ -259,7 +274,9 @@ std::vector ForLoopNest::AddLoopsForShapeOnDimensions( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, absl::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension)), + /*unroll_mode=*/llvm_ir::UnrollMode::kDefaultUnroll, + /*prevent_vectorization=*/false, shape.expressions(dimension)); multi_index[dimension] = loop->GetIndVarValue(); } return multi_index; diff --git a/third_party/xla/xla/service/llvm_ir/llvm_loop.h b/third_party/xla/xla/service/llvm_ir/llvm_loop.h index 7aa8ce9e32e950..5414f019121d7d 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_loop.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_loop.h @@ -201,14 +201,14 @@ class ForLoopNest { absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // Like the above, except that it defaults to a stride of one. 
std::unique_ptr AddLoop( absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. @@ -216,13 +216,13 @@ class ForLoopNest { int64_t start_index, int64_t end_index, int64_t stride, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( int64_t start_index, int64_t end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, - bool prevent_vectorization = false); + bool prevent_vectorization = false, DynExpr* expression = nullptr); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index b1db780705b41f..7790714743fe79 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -80,11 +80,15 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tsl/profiler/lib/scoped_annotation.h" +#include "xla/service/cpu/executable_run_options_offset.h" + namespace xla { namespace llvm_ir { namespace { +constexpr llvm::StringLiteral kBdimValueName("bdim_value"); + // This works for most llvm / mlir types. This also accepts a const pointer to // objects which have a const print() method. 
template @@ -288,9 +292,13 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::LLVMContext& context) { result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes().size()); } else if (shape.IsArray()) { - for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) { - result_type = - llvm::ArrayType::get(result_type, shape.dimensions(dimension)); + auto dimensions = LayoutUtil::MinorToMajor(shape); + for (int i = 0; i < dimensions.size(); i++) { + // The MinorToMajor order reverses dimensions... + bool is_dynamic = + shape.expressions(dimensions.size() - 1 - i)->is_dynamic(); + int64_t dim_val = is_dynamic ? 0 : shape.dimensions(dimensions[i]); + result_type = llvm::ArrayType::get(result_type, dim_val); } } return result_type; @@ -811,5 +819,136 @@ void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilderBase* b, b->SetInsertPoint(continued, continued->getFirstInsertionPt()); } +llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier, + int64_t offset) { + llvm::Function* function = b->GetInsertBlock()->getParent(); + llvm::LLVMContext& ctx = b->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + llvm::Value* loadedValue = nullptr; + llvm::Value* bdim_scaled = nullptr; + for (auto& inst : function->getEntryBlock()) { + if (inst.getName() == kBdimValueName) { + loadedValue = &inst; + } + } + if (!loadedValue) { + // if missing, try to materialize it by loading from run_options+offset + llvm::Value* run_options = nullptr; + for (auto& arg : function->args()) { + if (arg.getName() == "run_options") { + run_options = &arg; + break; + } + } + + if (!run_options) { + return nullptr; + } + // Materialize %bdim_value by loading ExecutableRunOptions::batch_size_ from + // the 'run_options' function argument using a known byte offset: + // + // 1) Bitcast 'run_options' to i8* to do byte-address arithmetic. + // 2) GEP by 'off' bytes to reach the batch_size field inside the object. 
+ // 3) Bitcast resulting i8* to i64* and load it. + const int64_t off = + static_cast(xla::cpu::ExecutableRunOptionsBatchSizeOffset()); + + // Insert at the entry block. + llvm::IRBuilder<> entry_builder( + &function->getEntryBlock(), + function->getEntryBlock().getFirstInsertionPt()); + llvm::Type* i8 = entry_builder.getInt8Ty(); + llvm::Value* ro_i8 = entry_builder.CreateBitCast( + run_options, llvm::PointerType::getUnqual(i8), "run_options_i8"); + llvm::Value* off_c = llvm::ConstantInt::get(i64Type, off); + llvm::Value* bdim_ptr_i8 = + entry_builder.CreateInBoundsGEP(i8, ro_i8, off_c, "bdim_ptr_i8"); + llvm::Value* bdim_ptr = entry_builder.CreateBitCast( + bdim_ptr_i8, llvm::PointerType::getUnqual(i64Type), "bdim_ptr"); + loadedValue = entry_builder.CreateLoad(i64Type, bdim_ptr, kBdimValueName); + } + if (multiplier < 1) { + llvm::errs() << "Multiplier is less than 1, this should not happen.\n"; + } else if (multiplier == 1) { + bdim_scaled = loadedValue; + } else { + llvm::ConstantInt* m = llvm::ConstantInt::get(i64Type, multiplier, true); + bdim_scaled = b->CreateMul(loadedValue, m, "bdim_scaled"); + } + if (offset != 0){ + llvm::ConstantInt* offset_value = + llvm::ConstantInt::get(i64Type, offset, true); + bdim_scaled = b->CreateAdd(bdim_scaled, offset_value, "bdim_offset"); + } + return bdim_scaled; +} + +llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr) { + llvm::Function* function = b->GetInsertBlock()->getParent(); + llvm::LLVMContext& ctx = b->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + if (expr == nullptr) return nullptr; + if (expr->is_constant()) + return llvm::ConstantInt::get(i64Type, expr->get_val(), true); + if (Variable* var_node = dynamic_cast(expr)) { + // For now we can just use %bdim... 
+ return GetBatchDimByName(b); + } + if (Mul* mul_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, mul_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, mul_node->get_rhs()); + return b->CreateMul(v_lhs, v_rhs, "mul_dims"); + } + // TODO: Check if this should ever happen + if (Div* div_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, div_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, div_node->get_rhs()); + return b->CreateUDiv(v_lhs, v_rhs, "div_dims"); + } + if (Add* add_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, add_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, add_node->get_rhs()); + return b->CreateAdd(v_lhs, v_rhs, "add_dims"); + } + if (Sub* sub_node = dynamic_cast(expr)) { + llvm::Value* v_lhs = EmitExpression(b, sub_node->get_lhs()); + llvm::Value* v_rhs = EmitExpression(b, sub_node->get_rhs()); + return b->CreateSub(v_lhs, v_rhs, "sub_dims"); + } + return nullptr; +} + +llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, + llvm::Value* base_ptr, + const std::vector& indices, + absl::Span dims, + absl::Span expressions, + llvm::Type* elem_type, + const llvm::Twine& name) { + llvm::Value* total_index = builder->getInt64(0); + llvm::Type* int64_ty = builder->getInt64Ty(); + + for (size_t i = 0; i < indices.size(); ++i) { + // The stride is the product of all dimensions to the right of this index. 
+ llvm::Value* stride = builder->getInt64(1); + for (size_t j = i; j < dims.size(); ++j) { + if (expressions[j]->is_dynamic()) { + llvm::Value* expr_value = + EmitExpression(builder, expressions[j]); + stride = builder->CreateMul(stride, expr_value, "stride.dyn"); + } else { + stride = builder->CreateMul( + stride, llvm::ConstantInt::get(int64_ty, dims[j]), "stride.static"); + } + } + llvm::Value* scaled_index = + builder->CreateMul(indices[i], stride, "idx.scaled"); + total_index = builder->CreateAdd(total_index, scaled_index, "idx.total"); + } + + // Final GEP: result = base + total_index * sizeof(elem_type) + return builder->CreateGEP(elem_type, base_ptr, total_index, name); +} + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.h b/third_party/xla/xla/service/llvm_ir/llvm_util.h index 88c1287d2f236d..c868c81574882f 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.h +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.h @@ -334,6 +334,19 @@ llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilderBase* b); void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilderBase* b, llvm::BasicBlock* return_block = nullptr); +llvm::Value* GetBatchDimByName(llvm::IRBuilderBase* b, int64_t multiplier = 1, + int64_t offset = 0); + +llvm::Value* createDynamicGEP(llvm::IRBuilderBase* builder, + llvm::Value* base_ptr, + const std::vector& indices, + absl::Span dims, + absl::Span expressions, + llvm::Type* elem_type, + const llvm::Twine& name = ""); + +llvm::Value* EmitExpression(llvm::IRBuilderBase* b, DynExpr* expr); + } // namespace llvm_ir } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index 13f6a67764b346..4156e37bf4c5dd 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -28,11 +28,13 @@ limitations under the License. 
#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "xla/layout_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_loop.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/platform/errors.h" @@ -183,6 +185,33 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( ForLoopNest loop_nest(loop_name, b_); + llvm::LLVMContext& ctx = b_->getContext(); + llvm::IntegerType* i64Type = llvm::IntegerType::getInt64Ty(ctx); + + llvm::PointerType* ptr = llvm::PointerType::getUnqual(ctx); + llvm::StructType* callFrameTy = llvm::StructType::create( + "XLA_CPU_KernelArg", ptr, ptr, i64Type, ptr, i64Type); + + std::vector dynamic_dims; + for (auto dim : shape_.dimensions()) { + dynamic_dims.push_back(llvm::ConstantInt::get(i64Type, dim)); + } + + bool dynamic = false; + for (int i = 0; i < shape_.dimensions_size(); i++) { + auto expr = shape_.expressions(i); + if (expr != nullptr && expr->is_dynamic()) { + dynamic_dims[i] = xla::llvm_ir::EmitExpression(b_, expr); + shape_.set_dynamic_dimension(i, true); + dynamic = true; + } + } + + if (dynamic) { + // Assign dynamic batch + dynamic_dims_ = dynamic_dims; + } + IrArray::Index array_index = dynamic_dims_.empty() ? EmitStaticIndex(&loop_nest, index_type) : EmitDynamicIndex(&loop_nest, index_type); diff --git a/third_party/xla/xla/service/llvm_ir/tuple_ops.cc b/third_party/xla/xla/service/llvm_ir/tuple_ops.cc index b4b3fa30affbcb..14bf83174ca269 100644 --- a/third_party/xla/xla/service/llvm_ir/tuple_ops.cc +++ b/third_party/xla/xla/service/llvm_ir/tuple_ops.cc @@ -104,7 +104,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64_t index, llvm::LoadInst* src_buffer = b->CreateLoad(element_pointee_type, element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. 
- if (!target_shape.IsOpaque()) { + if (!target_shape.IsOpaque() && !target_shape.has_dynamic_expr()) { SetDereferenceableMetadataForLoad( src_buffer, ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); diff --git a/third_party/xla/xla/service/reduce_scatter_combiner.cc b/third_party/xla/xla/service/reduce_scatter_combiner.cc index a86271036f6baa..9b889efb57e3a6 100644 --- a/third_party/xla/xla/service/reduce_scatter_combiner.cc +++ b/third_party/xla/xla/service/reduce_scatter_combiner.cc @@ -131,9 +131,9 @@ absl::Status CombineReduceScatters( std::swap((*perm)[most_frequent_dim], (*perm)[rs->scatter_dimension()]); // Bitcast operand and update output shape. + auto sh = ShapeUtil::PermuteDimensions(*perm, operand_shape); operands.back() = - computation.AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::PermuteDimensions(*perm, operand_shape), operand)); + computation.AddInstruction(HloInstruction::CreateBitcast(sh, operand)); output_shapes.back() = ShapeUtil::PermuteDimensions(*perm, hlo->shape()); } } diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 7985c930d812b5..2d0cc3b27cc123 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -195,6 +195,7 @@ absl::StatusOr InferWindowOutputShape(const Shape& base_shape, std::vector output_dimensions(window.dimensions_size()); std::vector output_is_dynamic(window.dimensions_size()); + std::vector output_expressions(window.dimensions_size()); for (int64_t i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -230,10 +231,12 @@ absl::StatusOr InferWindowOutputShape(const Shape& base_shape, padded_dilated_base, dilated_window, dim.stride()); } output_is_dynamic[i] = base_shape.is_dynamic_dimension(i); + output_expressions[i] = base_shape.expressions(i); } return ShapeUtil::MakeValidatedShape(element_type,
output_dimensions, - output_is_dynamic); + output_is_dynamic, + output_expressions); } // Encapsulates inferred dimension size and bound size. @@ -472,6 +475,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t last_dim = operand_shape.dimensions_size() - 1; std::vector is_dynamic(operand_shape.dimensions_size()); std::vector dimensions(operand_shape.dimensions_size()); + std::vector expressions(operand_shape.dimensions_size()); TF_RET_CHECK(operand_shape.dimensions(last_dim) >= k) << "k=" << k << " is larger than the last dimension of size=" @@ -480,10 +484,13 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, is_dynamic[i] = i == last_dim ? false : operand_shape.is_dynamic_dimension(i); dimensions[i] = i == last_dim ? k : operand_shape.dimensions(i); + expressions[i] = + i == last_dim ? xla::DynExpr::_(k) : operand_shape.expressions(i); } - Shape out = ShapeUtil::MakeShape(operand_shape.element_type(), dimensions, - is_dynamic); + Shape out = + ShapeUtil::MakeShape(operand_shape.element_type(), dimensions, is_dynamic, + expressions); Shape idxs_shape = ShapeUtil::ChangeElementType(out, PrimitiveType::S32); return ShapeUtil::MakeTupleShape({out, idxs_shape}); } @@ -542,9 +549,11 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t rank = arg_shape->dimensions_size(); std::vector inferred_sizes(rank, Shape::kUnboundedSize); std::vector inferred_bounds(rank, Shape::kUnboundedSize); + std::vector inferred_expressions(rank, DynExpr::zero); // Note: for the concatenate dimension, 0 should be the identity element: // Any dim size can keep unchanged when concatenated with 0 inferred_sizes[dimension] = 0; + inferred_expressions[dimension] = DynExpr::zero; for (const Shape* shape : arg_shapes) { for (int dim = 0; dim < rank; ++dim) { @@ -554,24 +563,32 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, int64_t leftSize = inferred_sizes[dim]; int64_t rightSize = dimension_size; int64_t leftBound = inferred_bounds[dim]; + 
xla::DynExpr* leftExpression = inferred_expressions[dim]; int64_t rightBound = shape->is_dynamic_dimension(dim) ? dimension_size : Shape::kUnboundedSize; + xla::DynExpr* rightExpression = shape->expressions(dim); + xla::DynExpr* inferred_expression = xla::DynExpr::zero; + if (dim == dimension) { inferred_dim_and_bound = InferConcatenatedDimAndBound( leftSize, rightSize, leftBound, rightBound); + inferred_expression = *leftExpression + *rightExpression; } else { TF_ASSIGN_OR_RETURN( inferred_dim_and_bound, InferMostSpecificDimAndBound(dim, leftSize, rightSize, leftBound, rightBound)); + inferred_expression = rightExpression; } inferred_sizes[dim] = inferred_dim_and_bound.dimension; inferred_bounds[dim] = inferred_dim_and_bound.bound; + inferred_expressions[dim] = inferred_expression->s(); } } - Shape result = ShapeUtil::MakeShape(element_type, inferred_sizes); + Shape result = + ShapeUtil::MakeShape(element_type, inferred_sizes, inferred_expressions); for (int64_t i = 0; i < inferred_bounds.size(); ++i) { if (!IsUnboundedDynamicSize(inferred_bounds[i]) || IsUnboundedDynamicSize(inferred_sizes[i])) { @@ -756,6 +773,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, std::vector dimensions(operand_shape.dimensions_size()); std::vector is_dynamic(operand_shape.dimensions_size()); + std::vector expressions(operand_shape.dimensions_size()); for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); if (operand_shape.is_unbounded_dynamic_dimension(i)) { @@ -771,11 +789,13 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, } } is_dynamic[i] = operand_shape.is_dynamic_dimension(i); + auto diff = dimensions[i] - operand_shape.dimensions(i); + expressions[i] = (*operand_shape.expressions(i) + diff)->s(); } return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), - dimensions, is_dynamic); + dimensions, is_dynamic, expressions); } // Current DotDimensionNumbers 
Requirements: @@ -920,7 +940,9 @@ absl::Status CheckDotDimensionConstraints( void GenerateDotResultDimensions( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, - std::vector& dimensions, std::vector& is_dynamic, + std::vector& dimensions, + std::vector& expressions, + std::vector& is_dynamic, std::vector rhs_group_dimensions = {}) { const auto& lhs_batch_dimensions = dimension_numbers.lhs_batch_dimensions(); const auto lhs_batch_dimensions_size = @@ -930,9 +952,11 @@ void GenerateDotResultDimensions( dimension_numbers.rhs_contracting_dimensions().size() - dimension_numbers.rhs_batch_dimensions().size(); dimensions.reserve(lhs_batch_dimensions_size); + expressions.reserve(lhs_batch_dimensions_size); is_dynamic.reserve(lhs_batch_dimensions_size); for (const int64_t lhs_dim : lhs_batch_dimensions) { dimensions.push_back(lhs.dimensions(lhs_dim)); + expressions.push_back(lhs.expressions(lhs_dim)); is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim)); } for (int64_t i = 0; i < lhs.dimensions_size(); i++) { @@ -940,6 +964,7 @@ void GenerateDotResultDimensions( i) && !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); + expressions.push_back(lhs.expressions(i)); is_dynamic.push_back(lhs.is_dynamic_dimension(i)); } } @@ -949,6 +974,7 @@ void GenerateDotResultDimensions( !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i) && !absl::c_linear_search(rhs_group_dimensions, i)) { dimensions.push_back(rhs.dimensions(i)); + expressions.push_back(rhs.expressions(i)); is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } @@ -990,12 +1016,14 @@ void GenerateDotResultDimensions( std::vector dimensions; std::vector is_dynamic; + std::vector expressions; GenerateDotResultDimensions(lhs, rhs, dimension_numbers, dimensions, - is_dynamic); + expressions, is_dynamic); PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); - Shape 
result = ShapeUtil::MakeShape(type, dimensions, is_dynamic); + Shape result = + ShapeUtil::MakeShape(type, dimensions, is_dynamic, expressions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -1193,6 +1221,7 @@ void GenerateDotResultDimensions( PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); std::vector dimensions; + std::vector expressions; std::vector is_dynamic; // Add the group dimension to the result shape in case of ragged contracting. if (mode == kContracting) { @@ -1200,9 +1229,10 @@ void GenerateDotResultDimensions( is_dynamic.push_back(is_dynamic_group_sizes); } GenerateDotResultDimensions(lhs, rhs, dimension_numbers, dimensions, - is_dynamic, rhs_group_dimensions); + expressions, is_dynamic, rhs_group_dimensions); - Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic); + Shape result = + ShapeUtil::MakeShape(type, dimensions, is_dynamic, expressions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred ragged dot shape: " << ShapeUtil::HumanString(result); return result; @@ -1248,12 +1278,15 @@ void GenerateDotResultDimensions( // Build the resulting shape dimensions. std::vector dimensions; std::vector is_dynamic; + std::vector expressions; for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) { dimensions.push_back(i != sparsity.dimension() ? operand_shape.dimensions(i) : metadata_dimension_size); is_dynamic.push_back(operand_shape.is_dynamic_dimension(i)); + expressions.push_back(operand_shape.expressions(i)); } - return ShapeUtil::MakeShape(element_type, dimensions, is_dynamic); + return ShapeUtil::MakeShape(element_type, dimensions, is_dynamic, + expressions); } /* static */ absl::StatusOr @@ -1267,6 +1300,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, // from the lhs/rhs pair in every index. 
std::vector output_dimensions(lhs.dimensions_size()); std::vector output_dimensions_is_dynamic(lhs.dimensions_size()); + std::vector output_dimensions_expressions(lhs.dimensions_size()); for (int64_t i = 0; i < lhs.dimensions_size(); ++i) { if (lhs.dimensions(i) == 1 || rhs.dimensions(i) == 1) { // For the unbounded case, the operand with 1 should be broadcasted to the @@ -1283,7 +1317,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions_is_dynamic[i] = lhs.dimensions(i) == 1 ? rhs.is_dynamic_dimension(i) : lhs.is_dynamic_dimension(i); - } else if (lhs.dimensions(i) == rhs.dimensions(i)) { + output_dimensions_expressions[i] = lhs.dimensions(i) == 1 + ? rhs.expressions(i) + : lhs.expressions(i); + } else if (lhs.dimensions(i) == rhs.dimensions(i)) { // && + // *lhs.expressions(i) == *rhs.expressions(i)) { // LHS | RHS | Result // X | X | X // X | <=X | <=X @@ -1293,6 +1331,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions[i] = lhs.dimensions(i); output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i) || rhs.is_dynamic_dimension(i); + output_dimensions_expressions[i] = lhs.expressions(i); } else if (lhs.is_unbounded_dynamic_dimension(i) || rhs.is_unbounded_dynamic_dimension(i)) { // For the last two rows, consider when <=X turns out to be 1 and ? turns @@ -1309,6 +1348,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions_is_dynamic[i] = lhs.is_unbounded_dynamic_dimension(i) ? rhs.is_dynamic_dimension(i) : lhs.is_dynamic_dimension(i); + output_dimensions_expressions[i] = lhs.is_unbounded_dynamic_dimension(i) + ? 
rhs.expressions(i) + : lhs.expressions(i); } else { return InvalidArgument("Binary op with incompatible shapes: %s and %s.", ShapeUtil::HumanString(lhs), @@ -1317,7 +1359,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - output_dimensions, output_dimensions_is_dynamic); + output_dimensions, output_dimensions_is_dynamic, + output_dimensions_expressions); } /* static */ absl::StatusOr ShapeInference::InferInDimBroadcastShape( @@ -1398,6 +1441,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, dimension_to_match, larger_shape.dimensions_size())); } int64_t small_dimension_size = smaller_shape.dimensions(i); + DynExpr* small_dimension_exp = smaller_shape.expressions(i); int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match); bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); bool large_is_dynamic = @@ -1436,6 +1480,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_shape.set_dimensions(dimension_to_match, small_dimension_size, small_is_dynamic); + output_shape.set_expression(dimension_to_match, small_dimension_exp); } return output_shape; @@ -1779,7 +1824,9 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { output_shape.element_type(), arg_shape->dimensions(), /*dynamic_dimensions=*/ std::vector(arg_shape->dynamic_dimensions().begin(), - arg_shape->dynamic_dimensions().end())); + arg_shape->dynamic_dimensions().end()), + /*expressions=*/ + arg_shape->expressions()); } /* static */ absl::StatusOr ShapeInference::InferBatchNormTrainingShape( @@ -1864,8 +1911,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const int64_t feature_count = operand_shape.dimensions(feature_index); bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); - Shape output_shape_for_mean_and_var = ShapeUtil::MakeShape( - 
operand_shape.element_type(), {feature_count}, {dynamic_feature}); + DynExpr* expression_feature = + operand_shape.expressions(feature_index); + + Shape output_shape_for_mean_and_var = + ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}, + {dynamic_feature}, {expression_feature}); if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(offset_shape, 0), feature_count)) { @@ -2148,8 +2199,12 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const int64_t feature_count = operand_shape.dimensions(feature_index); bool dynamic_feature = operand_shape.is_dynamic_dimension(feature_index); + DynExpr* expression_feature = + operand_shape.expressions(feature_index); + Shape feature_shape = ShapeUtil::MakeShape( - operand_shape.element_type(), {feature_count}, {dynamic_feature}); + operand_shape.element_type(), {feature_count}, {dynamic_feature}, + {expression_feature}); if (!CompatibleDimensionSizes(ShapeUtil::GetDimension(mean_shape, 0), feature_count)) { @@ -2398,13 +2453,17 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } std::vector dynamic_dimensions(input_spatial_dims.size()); + std::vector expressions(input_spatial_dims.size()); for (auto it = input_spatial_dims.begin(); it != input_spatial_dims.end(); ++it) { dynamic_dimensions[it - input_spatial_dims.begin()] = IsUnboundedDynamicSize(*it); + expressions[it - input_spatial_dims.begin()] = + DynExpr::_(-70); } Shape base_shape = ShapeUtil::MakeShape( - lhs.element_type(), input_spatial_dims, dynamic_dimensions); + lhs.element_type(), input_spatial_dims, dynamic_dimensions, + expressions); TF_ASSIGN_OR_RETURN( Shape window_output_shape, InferWindowOutputShape(base_shape, window, lhs.element_type())); @@ -2457,7 +2516,8 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } PrimitiveType type = preferred_element_type.value_or( ShapeUtil::HigherPrecisionElementType(lhs, rhs)); - return ShapeUtil::MakeShape(type, dimensions, is_dynamic); + return 
ShapeUtil::MakeShape(type, dimensions, is_dynamic, + expressions); } /* static */ absl::StatusOr ShapeInference::InferFftShape( @@ -2780,8 +2840,10 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const std::vector dynamic_dimensions(shape.dynamic_dimensions().begin(), shape.dynamic_dimensions().end()); + auto exprs = shape.expressions(); + std::vector expressions(exprs.begin(), exprs.end()); return ShapeUtil::MakeShape(shape.element_type(), new_dimensions, - dynamic_dimensions); + dynamic_dimensions, expressions); } /* static */ absl::StatusOr ShapeInference::InferAllToAllTupleShape( @@ -2923,23 +2985,28 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { std::vector new_dimensions; std::vector new_is_dynamic; + std::vector new_expressions; for (int i = 0; i < arg.dimensions_size(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); new_is_dynamic.push_back(arg.is_dynamic_dimension(i)); + new_expressions.push_back(arg.expressions(i)); } } if (ShapeUtil::IsScalar(to_apply.result())) { - return ShapeUtil::MakeShape(to_apply.result().element_type(), - new_dimensions, new_is_dynamic); + return ShapeUtil::MakeShape( + to_apply.result().element_type(), new_dimensions, new_is_dynamic, + new_expressions); } else { std::vector result_subshapes; const auto& tuple_shapes = to_apply.result().tuple_shapes(); result_subshapes.reserve(tuple_shapes.size()); for (const Shape& subshape : tuple_shapes) { - result_subshapes.push_back(ShapeUtil::MakeShape( - subshape.element_type(), new_dimensions, new_is_dynamic)); + auto new_shape = ShapeUtil::MakeShape( + subshape.element_type(), new_dimensions, new_is_dynamic, + new_expressions); + result_subshapes.push_back(new_shape); } return ShapeUtil::MakeTupleShape(result_subshapes); } @@ -3172,7 +3239,9 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr 
ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, - absl::Span limits, absl::Span strides) { + absl::Span limits, absl::Span strides, + absl::Span start_exprs, + absl::Span limit_exprs) { auto error = [&](const std::string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " @@ -3202,8 +3271,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } std::vector sizes; + std::vector expressions; const auto starts_size = starts.size(); sizes.reserve(starts_size); + expressions.reserve(starts_size); for (int64_t dimension = 0; dimension < starts_size; ++dimension) { int64_t start_index = starts[dimension]; int64_t limit_index = limits[dimension]; @@ -3231,6 +3302,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); + + auto limit_expr = + limit_exprs.empty() ? DynExpr::_(limit_index) : limit_exprs[dimension]; + auto start_expr = + start_exprs.empty() ? 
DynExpr::_(start_index) : start_exprs[dimension]; + + auto new_expr = (*(*(*limit_expr - *start_expr) + stride) - 1)->s(); + expressions.push_back((*new_expr/stride)->s()); } std::vector is_dynamic(arg.dimensions_size()); @@ -3242,12 +3321,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { is_dynamic[i] = arg.is_bounded_dynamic_dimension(i); } - return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic); + return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic, + expressions); } /* static */ absl::StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, - absl::Span slice_sizes, bool allow_scalar_indices) { + absl::Span slice_sizes, + absl::Span slice_exprs, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); auto number_of_indices = start_index_shapes.size(); // TODO(b/118437727): Remove this path. @@ -3346,8 +3427,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } - Shape result = - ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); + Shape result = ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes, + slice_exprs); for (int64_t dimension = 0; dimension < operand_shape.dimensions_size(); ++dimension) { @@ -3646,7 +3727,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } /* static */ absl::StatusOr ShapeInference::InferBroadcastShape( - const Shape& operand, absl::Span broadcast_sizes) { + const Shape& operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs) { // This method is used to infer shape for xla::BroadcastInDim. 
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); TF_RET_CHECK(!operand.is_unbounded_dynamic()); @@ -3666,12 +3748,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { std::copy(operand.dimensions().begin(), operand.dimensions().end(), dimensions.begin() + broadcast_sizes.size()); - TF_ASSIGN_OR_RETURN(Shape result, ShapeUtil::MakeValidatedShape( - operand.element_type(), dimensions)); + TF_ASSIGN_OR_RETURN( + Shape result, ShapeUtil::MakeValidatedShape(operand.element_type(), + dimensions, broadcast_exprs)); for (int64_t i = 0; i < operand.dimensions_size(); ++i) { result.set_dynamic_dimension(broadcast_sizes.size() + i, operand.is_dynamic_dimension(i)); } + return result; } @@ -3735,7 +3819,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic) { + const std::vector& dims_are_dynamic, + absl::Span expressions) { if (new_size_bounds.size() != dims_are_dynamic.size()) { return InvalidArgument( "DynamicReshape has to have the same number of elements in new_sizes " @@ -3751,9 +3836,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { dim_size_shape->ToString()); } } - Shape inferred_shape = ShapeUtil::MakeShape( - operand.element_type(), new_size_bounds, dims_are_dynamic); + operand.element_type(), new_size_bounds, dims_are_dynamic, expressions); if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( "Reshape operation has mismatched element counts: from=%d (%s) " @@ -3767,10 +3851,21 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, - int64_t inferred_dimension) { + int64_t inferred_dimension, 
absl::Span expressions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = - ShapeUtil::MakeShape(operand.element_type(), dimensions); + ShapeUtil::MakeShape(operand.element_type(), dimensions, expressions); + + if (expressions.empty() && operand.expressions().size() > 0 && + operand.expressions(0) != nullptr && operand.expressions(0)->is_dynamic()) { + return InvalidArgument("Expressions is empty but operand is dynamic"); + } + + // if (!expressions.empty() && expressions[0]->is_constant() && + // expressions[0]->get_val() == 977) { + // return InvalidArgument("Expressions[0] is the magic number (977)."); + // } + VLOG(3) << "Reshape inferred shape: " << ShapeUtil::HumanString(inferred_shape); @@ -4250,18 +4345,25 @@ static absl::Status ValidateGatherDimensionNumbers( std::vector expanded_start_indices_shape; // Also tracks if an output dimension is dynamic. std::vector expanded_start_indices_shape_dynamic_dimensions; + std::vector expanded_start_indices_shape_expressions; expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); expanded_start_indices_shape_dynamic_dimensions.reserve( start_indices_shape.dimensions_size()); + expanded_start_indices_shape_expressions.reserve( + start_indices_shape.dimensions_size()); absl::c_copy(start_indices_shape.dimensions(), std::back_inserter(expanded_start_indices_shape)); absl::c_copy( start_indices_shape.dynamic_dimensions(), std::back_inserter(expanded_start_indices_shape_dynamic_dimensions)); + absl::c_copy( + start_indices_shape.expressions(), + std::back_inserter(expanded_start_indices_shape_expressions)); if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { expanded_start_indices_shape.push_back(1); expanded_start_indices_shape_dynamic_dimensions.push_back(false); + expanded_start_indices_shape_expressions.push_back(DynExpr::one); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( @@ -4328,10 +4430,12 @@ static absl::Status 
ValidateGatherDimensionNumbers( output_dim_bounds.reserve(result_rank); std::vector output_dim_is_dynamic; + std::vector output_expressions; output_dim_is_dynamic.reserve(result_rank); for (int64_t i = 0; i < result_rank; i++) { int64_t current_bound; bool dim_dynamic = false; + DynExpr* expression = DynExpr::_(-80); bool is_window_index = absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { @@ -4353,6 +4457,9 @@ static absl::Status ValidateGatherDimensionNumbers( if (slice_sizes[offset_dims_seen] == input_shape.dimensions(offset_dims_seen)) { dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen); + expression = input_shape.expressions(offset_dims_seen); + } else { + expression = DynExpr::_(slice_sizes[offset_dims_seen]); } current_bound = slice_sizes[offset_dims_seen++]; } else { @@ -4362,15 +4469,17 @@ static absl::Status ValidateGatherDimensionNumbers( // Forward dynamic dimensions from indices. dim_dynamic = expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen]; - + expression = expanded_start_indices_shape_expressions[gather_dims_seen]; current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_is_dynamic.push_back(dim_dynamic); + output_expressions.push_back(expression); output_dim_bounds.push_back(current_bound); } - return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds, - output_dim_is_dynamic); + auto s = ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds, + output_dim_is_dynamic, output_expressions); + return s; } namespace { diff --git a/third_party/xla/xla/service/shape_inference.h b/third_party/xla/xla/service/shape_inference.h index 15818f01cdfa7e..4f5d0f4b2e3efa 100644 --- a/third_party/xla/xla/service/shape_inference.h +++ b/third_party/xla/xla/service/shape_inference.h @@ -245,13 +245,17 @@ class ShapeInference { // e.g. 
slice f32[32x32] 0:16 0:16 -> f32[16x16] static absl::StatusOr InferSliceShape( const Shape& arg, absl::Span starts, - absl::Span limits, absl::Span strides); + absl::Span limits, absl::Span strides, + absl::Span start_exprs = {}, + absl::Span limit_exprs = {}); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static absl::StatusOr InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, - absl::Span slice_sizes, bool allow_scalar_indices = true); + absl::Span slice_sizes, + absl::Span slice_exprs = {}, + bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -283,7 +287,8 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static absl::StatusOr InferBroadcastShape( - const Shape& operand, absl::Span broadcast_sizes); + const Shape& operand, absl::Span broadcast_sizes, + absl::Span broadcast_exprs = {}); // Checks whether the given parameters can form a broadcast. Returns the same // output_shape if it's legal. @@ -295,7 +300,7 @@ class ShapeInference { // its operand and the new dimension sizes specified. static absl::StatusOr InferReshapeShape( const Shape& operand, absl::Span dimensions, - int64_t inferred_dimension); + int64_t inferred_dimension, absl::Span expressions = {}); // Infers the shape produced by a dynamic reshape operation from the element // type of its operand and the new dimension sizes specified. 
The result shape @@ -304,7 +309,8 @@ class ShapeInference { static absl::StatusOr InferDynamicReshapeShape( const Shape& operand, absl::Span dim_size_shapes, absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); + const std::vector& dims_are_dynamic, + absl::Span expressions); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. diff --git a/third_party/xla/xla/service/triangular_solve_expander.cc b/third_party/xla/xla/service/triangular_solve_expander.cc index 67dc0235810de8..ea587461b1e277 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.cc +++ b/third_party/xla/xla/service/triangular_solve_expander.cc @@ -120,7 +120,12 @@ XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { auto last_blocks_dims = std::vector(ndims); std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); - last_blocks = Reshape(last_blocks, last_blocks_dims); + auto shape_exprs = blocks_shape.expressions(); + auto last_blocks_exprs = std::vector(ndims); + std::copy(shape_exprs.begin(), shape_exprs.end(), + last_blocks_exprs.begin()); + last_blocks_exprs.insert(last_blocks_exprs.end() - 2, DynExpr::one); + last_blocks = Reshape(last_blocks, last_blocks_dims, last_blocks_exprs); // Concatenate with the other blocks if necessary if (n > block_size) { @@ -366,7 +371,7 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( /*broadcast_dimensions=*/{0, 1}); // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, shape.dimensions()); + return Reshape(inv_diag_blocks, shape.dimensions(), shape.expressions()); }); } diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 1cee38146fb07d..962984b9225b6a 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -41,6 +41,295 @@ limitations under the License. 
namespace xla { +DynExpr* ExprFromProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DynExpr::_(proto.constant_value()); + + case ExpressionProto::kVariableId: + return DynExpr::V(proto.variable_id()); + + case ExpressionProto::kAddNode: { + const auto& add = proto.add_node(); + return *ExprFromProto(add.lhs()) + *ExprFromProto(add.rhs()); + } + + case ExpressionProto::kSubNode: { + const auto& sub = proto.sub_node(); + return *ExprFromProto(sub.lhs()) - *ExprFromProto(sub.rhs()); + } + + case ExpressionProto::kMulNode: { + const auto& mul = proto.mul_node(); + return *ExprFromProto(mul.lhs()) * *ExprFromProto(mul.rhs()); + } + + case ExpressionProto::kDivNode: { + const auto& div = proto.div_node(); + return *ExprFromProto(div.lhs()) / *ExprFromProto(div.rhs()); + } + + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +DynExpr* operator*(DynExpr& lhs, DynExpr& rhs) { return new Mul(&lhs, &rhs); } +DynExpr* operator*(int64_t k, DynExpr& rhs) { + return new Mul(DynExpr::_(k), &rhs); +} +DynExpr* operator/(DynExpr& lhs, DynExpr& rhs) { return new Div(&lhs, &rhs); } +DynExpr* operator/(DynExpr& lhs, int64_t d) { + return new Div(&lhs, DynExpr::_(d)); +} +DynExpr* operator+(DynExpr& lhs, DynExpr& rhs) { return new Add(&lhs, &rhs); } +DynExpr* operator+(DynExpr& lhs, int64_t d) { + return new Add(&lhs, DynExpr::_(d)); +} +DynExpr* operator-(DynExpr& lhs, DynExpr& rhs) { return new Sub(&lhs, &rhs); } +DynExpr* operator-(DynExpr& lhs, int64_t d) { + return new Sub(&lhs, DynExpr::_(d)); +} +bool operator==(DynExpr& lhs, DynExpr& rhs) { + return DynExpr::equal(&lhs, &rhs); +} +bool operator==(DynExpr& lhs, int64_t d) { + return DynExpr::equal(&lhs, DynExpr::_(d)); +} +bool operator<(DynExpr& lhs, int64_t d) { + return lhs.is_constant() && lhs.get_val() < d; +} + +bool DynExpr::equal(DynExpr* expr1, DynExpr* expr2) { + auto e1 = expr1->s(); + auto e2 = expr2->s(); + if (e1 
== nullptr || e2 == nullptr) return false; + Constant* c1 = dynamic_cast(e1); + Constant* c2 = dynamic_cast(e2); + if (c1 && c2) return c1->get_val() == c2->get_val(); + // Var x = Var y <=> x = y + if (Variable* varx = dynamic_cast(e1), + *vary = dynamic_cast(e2); + varx && vary) { + return varx->get_id() == vary->get_id(); + } + // a * b = c * d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (Mul* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto a = ab->get_lhs(); + auto b = ab->get_rhs(); + auto c = cd->get_lhs(); + auto d = cd->get_rhs(); + return (*a == *c && *b == *d) || (*a == *d && *b == *c); + } + // a / b = c / d <=> (a = c /\ b = d) + if (Div* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto a = ab->get_lhs(); + auto b = ab->get_rhs(); + auto c = cd->get_lhs(); + auto d = cd->get_rhs(); + return *a == *c && *b == *d; + } + // a + b = c + d <=> (a = c /\ b = d) \/ (a = d /\ b = c) + if (Add* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto a = ab->get_lhs(); + auto b = ab->get_rhs(); + auto c = cd->get_lhs(); + auto d = cd->get_rhs(); + return (*a == *c && *b == *d) || (*a == *d && *b == *c); + } + // a - b = c - d <=> (a = c /\ b = d) + if (Sub* ab = dynamic_cast(e1), *cd = dynamic_cast(e2); + ab && cd) { + auto* a = ab->get_lhs(); + auto* b = ab->get_rhs(); + auto* c = cd->get_lhs(); + auto* d = cd->get_rhs(); + return *a == *c && *b == *d; + } + return false; +} + +// Simplification methods +DynExpr* Constant::s() { return this; } + +DynExpr* Variable::s() { return this; } + +DynExpr* Mul::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant * constant + if (l && r) return DynExpr::_(l->get_val() * r->get_val()); + // 0 * X = 0 + if (l && l->get_val() == 0) return DynExpr::zero; + // 1 * X = X + if (l && l->get_val() == 1) return s_rhs; + // X * 1 = X + if (r && r->get_val() == 1) return 
s_lhs; + // X * constant = constant * X + if (r && s_lhs->is_dynamic()) return (r->get_val() * *s_lhs)->s(); + // m * (nX) = (m*n) * X + if (Mul* nX = dynamic_cast(s_rhs)) { + DynExpr* X = nX->get_rhs(); + Constant* n = dynamic_cast(nX->get_lhs()); + if (l && n) { + auto mn = l->get_val() * n->get_val(); + return (mn * *X)->s(); + } + } + return (*s_lhs) * (*s_rhs); +} + +DynExpr* Add::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant + constant + if (l && r) return DynExpr::_(l->get_val() + r->get_val()); + // 0 + X = X + if (l && l->get_val() == 0) return s_rhs; + // X + 0 = X + if (r && r->get_val() == 0) return s_lhs; + // m + X = X + m + if (l && s_rhs->is_dynamic()) return (*s_rhs + l->get_val())->s(); + // X + X = 2 * X + if (*s_lhs == *s_rhs) { + return (2 * (*s_rhs))->s(); + } + // nX + X = (n+1) * X + if (Mul* nX = dynamic_cast(s_lhs)) { + DynExpr* n = nX->get_lhs(); + DynExpr* X = nX->get_rhs(); + if (*X == *s_rhs) { + return (*(*n + 1) * (*X))->s(); + } + } + // X + nX = (n+1) * X + if (Mul* nX = dynamic_cast(s_rhs)) { + DynExpr* n = nX->get_lhs(); + DynExpr* X = nX->get_rhs(); + if (*X == *s_lhs) { + return (*(*n + 1) * (*X))->s(); + } + } + // mX + nX = (m+n) * X + if (Mul* mX = dynamic_cast(s_lhs), *nY = dynamic_cast(s_rhs); + mX && nY) { + DynExpr* m = mX->get_lhs(); + DynExpr* X = mX->get_rhs(); + DynExpr* n = nY->get_lhs(); + DynExpr* Y = nY->get_rhs(); + if (*X == *Y) { + return (*(*m + *n) * (*X))->s(); + } + } + // (X + Y) + Z = X + (Y + Z) + if (Add* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X + *(*Y + *s_rhs))->s(); + } + // (X - Y) + Z = X - (Y - Z) + if (Sub* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X - *(*Y - *s_rhs))->s(); + } + return *s_lhs + *s_rhs; +} + +DynExpr* Sub::s() { + if (!get_lhs()){ + LOG(INFO) << 
"NO LEFT"; + } + + if (!get_rhs()){ + LOG(INFO) << "NO RIGHT"; + } + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant - constant + if (l && r) return DynExpr::_(l->get_val() - r->get_val()); + // X - 0 = X + if (r && r->get_val() == 0) return s_lhs; + // X - X = 0 + if (*s_lhs == *s_rhs) { + return DynExpr::zero; + } + // mX - nX = (m-n) * X + if (Mul* mX = dynamic_cast(s_lhs), *nY = dynamic_cast(s_rhs); + mX && nY) { + DynExpr* m = mX->get_lhs(); + DynExpr* X = mX->get_rhs(); + DynExpr* n = nY->get_lhs(); + DynExpr* Y = nY->get_rhs(); + if (*X == *Y) { + return (*(*m - *n) * (*X))->s(); + } + } + // (X + Y) - X = X + (Y - Z) + if (Add* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X + *(*Y - *s_rhs))->s(); + } + // (X - Y) - Z = X - (Y + Z) + if (Sub* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X - *(*Y + *s_rhs))->s(); + } + return *s_lhs - *s_rhs; +} + +DynExpr* Div::s() { + DynExpr* s_lhs = get_lhs()->s(); + DynExpr* s_rhs = get_rhs()->s(); + Constant* l = dynamic_cast(s_lhs); + Constant* r = dynamic_cast(s_rhs); + // constant / constant + if (l && r) return DynExpr::_(l->get_val() / r->get_val()); + // X / 1 = X + if (r && r->get_val() == 1) return s_lhs; + // (X + Y) / Z = (X/Z) + (Y/Z) + if (Add* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*((*X) / (*s_rhs)) + *((*Y) / (*s_rhs)))->s(); + } + // (X * Y) / Z = (X/Z) * Y + if (Mul* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*(*X / (*s_rhs)) * (*Y))->s(); + } + // (X / Y) / Z = X / (Y*Z) + if (Div* XY = dynamic_cast(s_lhs)) { + DynExpr* X = XY->get_lhs(); + DynExpr* Y = XY->get_rhs(); + return (*X / *(*Y * *s_rhs))->s(); + } + return *s_lhs / *s_rhs; +} + +std::ostream& operator<<(std::ostream& 
os, DynExpr* expr) { + ExpressionProto proto; + expr->to_proto(&proto); + os << proto.ShortDebugString(); + return os; +} + +DynExpr* DynExpr::zero = new Constant(0); +DynExpr* DynExpr::one = new Constant(1); + // Defined in .cc file to avoid inlining these large routines Shape::Shape() = default; Shape::~Shape() = default; @@ -97,13 +386,18 @@ Shape::Shape(const ShapeProto& shape_proto) { } absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { + + // LOG(INFO) << "FROM PROTO:\n" << shape_proto.DebugString() << std::endl; + Shape shape; shape.set_element_type(shape_proto.element_type()); if (auto* const state = shape.if_array_state()) { const int num_dims = shape_proto.dimensions_size(); const int num_is_dynamic_dims = shape_proto.is_dynamic_dimension_size(); + const int num_expressions = shape_proto.expressions_size(); state->dimensions.reserve(num_dims); state->dynamic_dimensions.reserve(num_dims); + state->expressions.reserve(num_dims); if (num_is_dynamic_dims != 0) { TF_RET_CHECK(num_dims == num_is_dynamic_dims) << "Malformed shape proto: number of is_dynamic_dimension " @@ -111,6 +405,13 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { << num_is_dynamic_dims << ") does not match number of dimension " << "fields (" << num_dims << ")."; } + if (num_expressions != 0) { + TF_RET_CHECK(num_dims == num_expressions) + << "Malformed shape proto: number of expressions " + "fields (" + << num_expressions << ") does not match number of dimension " + << "fields (" << num_dims << ")."; + } for (int i = 0; i < num_dims; ++i) { const bool is_dynamic = (i < num_is_dynamic_dims) && shape_proto.is_dynamic_dimension(i); @@ -118,7 +419,11 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { // UnsafeAddDimension. We expect that the caller will eventually call a // validation routine that will detect the error in case the dimension // value is invalid. 
- shape.UnsafeAddDimension(shape_proto.dimensions(i), is_dynamic); + DynExpr* expression = (i < num_expressions) + ? ExprFromProto(shape_proto.expressions(i)) + : DynExpr::_(shape_proto.dimensions(i)); + shape.UnsafeAddDimension(shape_proto.dimensions(i), is_dynamic, + expression); } } else if (auto* const state = shape.if_tuple_state()) { state->tuple_shapes.reserve(shape_proto.tuple_shapes_size()); @@ -136,6 +441,7 @@ absl::StatusOr Shape::FromProto(const ShapeProto& shape_proto) { TF_ASSIGN_OR_RETURN(*shape.mutable_layout(), Layout::FromProto(shape_proto.layout())); } + // LOG(INFO) << "FROM PROTO " << shape << "\n"; return shape; } @@ -143,6 +449,8 @@ ShapeProto Shape::ToProto() const { ShapeProto proto; proto.set_element_type(element_type_); + // LOG(INFO) << "TO PROTO " << ToString() << "\n"; + if (const auto* const state = if_array_state()) { proto.mutable_dimensions()->Reserve(state->dimensions.size()); for (const int64_t dimension : state->dimensions) { @@ -151,6 +459,11 @@ ShapeProto Shape::ToProto() const { for (const bool dynamic : state->dynamic_dimensions) { proto.add_is_dynamic_dimension(dynamic); } + for (const DynExpr* e : state->expressions) { + ExpressionProto* eproto = proto.add_expressions(); + CHECK(e != nullptr) << "Missing expression in expression list."; + e->to_proto(eproto); + } if (state->layout.has_value()) { *proto.mutable_layout() = state->layout->ToProto(); } @@ -163,6 +476,7 @@ ShapeProto Shape::ToProto() const { proto.mutable_tuple_shapes()->Reserve(1); *proto.add_tuple_shapes() = state->buffer_shape[0].ToProto(); } + // LOG(INFO) << "DEBUG VIEW:\n" << proto.DebugString() << std::endl; return proto; } @@ -245,14 +559,15 @@ bool Shape::AreAllLeavesIntegers() const { return primitive_util::IsIntegralType(element_type()); } -void Shape::add_dimensions(int64_t value, bool is_dynamic) { +void Shape::add_dimensions(int64_t value, bool is_dynamic, DynExpr* expr) { if (value < 0) { CHECK(is_dynamic) << "static dimension must have size >= 0 
instead of " << value << "."; CHECK_EQ(value, kUnboundedSize) << "dynamic dimension must have size == kUnboundedSize or >= 0."; } - UnsafeAddDimension(value, is_dynamic); + UnsafeAddDimension(value, is_dynamic, + expr != nullptr ? expr : DynExpr::_(value)); } void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { @@ -262,6 +577,23 @@ void Shape::set_dynamic_dimension(int dimension, bool is_dynamic) { state.dynamic_dimensions[dimension] = is_dynamic; } +void Shape::set_expression(int dimension, DynExpr* e) { + auto& state = array_state(); + state.expressions[dimension] = + e != nullptr ? e : DynExpr::_(state.dimensions[dimension]); +} + +void Shape::set_expressions(std::vector exps) { + auto& state = array_state(); + CHECK_LE(exps.size(), state.dimensions.size()); + state.expressions.resize(state.dimensions.size()); + for (size_t i = 0; i < state.dimensions.size(); ++i) { + DynExpr* expr = i < exps.size() ? exps[i] : DynExpr::_(state.dimensions[i]); + state.expressions[i] = + expr != nullptr ? 
expr : DynExpr::_(state.dimensions[i]); + } +} + void Shape::set_dimensions(int index, int64_t size, std::optional is_dynamic) { auto& state = array_state(); @@ -270,6 +602,7 @@ void Shape::set_dimensions(int index, int64_t size, CheckDimensionSize(index, size, dynamic); state.dimensions[index] = size; state.dynamic_dimensions[index] = dynamic; + state.expressions[index] = DynExpr::_(size); } void Shape::set_dimensions_minor(int index, int64_t size, @@ -291,12 +624,15 @@ void Shape::CheckDimensionSize(int dim_index, int64_t size, bool is_dynamic) { } } -void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic) { +void Shape::UnsafeAddDimension(int64_t value, bool is_dynamic, DynExpr* exp) { auto& state = array_state(); CHECK_EQ(state.dimensions.size(), state.dynamic_dimensions.size()) << "where the shape is " << ToString(); + CHECK_EQ(state.dimensions.size(), state.expressions.size()) + << "where the shape is " << ToString(); state.dimensions.push_back(value); state.dynamic_dimensions.push_back(is_dynamic); + state.expressions.push_back(exp != nullptr ? 
exp : DynExpr::_(value)); } bool Shape::is_static() const { @@ -345,6 +681,8 @@ void Shape::DeleteDimension(int64_t dim_to_delete) { state.dimensions.erase(state.dimensions.begin() + dim_to_delete); state.dynamic_dimensions.erase(state.dynamic_dimensions.begin() + dim_to_delete); + state.expressions.erase(state.expressions.begin() + + dim_to_delete); if (LayoutUtil::HasLayout(*this)) { state.layout->DeleteDimension(dim_to_delete); // NOLINT: optional-access } @@ -358,6 +696,8 @@ void Shape::DeleteDimensions(absl::Span dims_to_delete) { state.dimensions = RemoveElements(sorted_dims_to_delete, state.dimensions); state.dynamic_dimensions = RemoveElements(sorted_dims_to_delete, state.dynamic_dimensions); + state.expressions = + RemoveElements(sorted_dims_to_delete, state.expressions); if (LayoutUtil::HasLayout(*this)) { for (auto it = sorted_dims_to_delete.rbegin(); it != sorted_dims_to_delete.rend(); ++it) { @@ -370,6 +710,7 @@ void Shape::CheckStateIsEmpty() const { if (const auto* const state = if_array_state()) { CHECK(state->dimensions.empty()) << ToString(); CHECK(state->dynamic_dimensions.empty()) << ToString(); + CHECK(state->expressions.empty()) << ToString(); CHECK(!state->layout.has_value()) << ToString(); } else if (const auto* const state = if_tuple_state()) { CHECK(state->tuple_shapes.empty()) << ToString(); @@ -509,6 +850,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { rhs.is_unbounded_dynamic_dimension(i))) { continue; } + if (i == 0 && ignore_batch_ && + (lhs.outer_multiplier() > 0 || rhs.outer_multiplier() > 0)) { + VLOG(3) << "CompareShapes: batch dimension found. 
Forcely compatible"; + continue; + } if (lhs.dimensions(i) != rhs.dimensions(i)) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; diff --git a/third_party/xla/xla/shape.h b/third_party/xla/xla/shape.h index 8453dc17717e11..1c5197ad838a88 100644 --- a/third_party/xla/xla/shape.h +++ b/third_party/xla/xla/shape.h @@ -38,6 +38,7 @@ limitations under the License. #include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "xla/shape_dynexpr.h" namespace xla { @@ -218,6 +219,29 @@ class Shape { return array_state().dynamic_dimensions[dimension]; } + bool has_dynamic_expr() const { + if (auto* const state = if_array_state()) { + return absl::c_any_of(state->expressions, + [](DynExpr* e) { + return e != nullptr && e->is_dynamic(); + }); + } + if (auto* const state = if_tuple_state()) { + return absl::c_any_of(state->tuple_shapes, [](Shape subshape) { + return subshape.has_dynamic_expr(); + }); + } + return false; + } + + DynExpr* expressions(int dimension) const { + if (dimension < 0) return DynExpr::_(-999); + const auto& exprs = array_state().expressions; + const size_t dim = static_cast(dimension); + if (dim >= exprs.size()) return DynExpr::_(-999); + return exprs[dim] != nullptr ? exprs[dim] : DynExpr::_(-999); + } + // Returns true if the given dimension is statically-sized. // Precondition: this is an array shape and `dimension` is a valid dimension // index. @@ -232,12 +256,20 @@ class Shape { // - The dimension's size is valid for the given dynamic-ness. void set_dynamic_dimension(int dimension, bool is_dynamic); + void set_expression(int dimension, DynExpr* e); + + void set_expressions(std::vector exprs); + // Returns a span to indicate whether each dimension is dynamic. // Precondition: this is an array shape. 
absl::Span dynamic_dimensions() const { return array_state().dynamic_dimensions; } + absl::Span expressions() const { + return array_state().expressions; + } + // Removes the given dimension from the shape. Layout, if it exists, is // adjusted to match the modified shape. // Precondition: this is an array shape, and the input dimension indices are @@ -313,7 +345,8 @@ class Shape { // - This is an array shape. // - Either `value` is >= 0, or `is_dynamic` is true and `value` is // kUnboundedSize. - void add_dimensions(int64_t value, bool is_dynamic = false); + void add_dimensions(int64_t value, bool is_dynamic = false, + xla::DynExpr* expr = nullptr); // Clears all dimensions (i.e. makes this shape a scalar). // Precondition: this is an array shape. @@ -321,6 +354,7 @@ class Shape { auto& state = array_state(); state.dimensions.clear(); state.dynamic_dimensions.clear(); + state.expressions.clear(); } // Returns a span to indicate the size of each dimension. @@ -434,6 +468,10 @@ class Shape { bool operator()(const Shape& lhs, const Shape& rhs); + Equal& IgnoreBatch(bool ignore_batch = true) { + ignore_batch_ = ignore_batch; + return *this; + } Equal& IgnoreLayout(bool ignore_layout = true) { ignore_layout_ = ignore_layout; return *this; @@ -488,6 +526,7 @@ class Shape { } private: + bool ignore_batch_ = false; bool ignore_layout_ = false; bool ignore_tiles_in_layout_ = false; bool ignore_element_size_in_layout_ = false; @@ -515,7 +554,7 @@ class Shape { } if (const auto* const state = s.if_array_state()) { h = H::combine(std::move(h), s.element_type_, state->dimensions, - state->dynamic_dimensions); + state->dynamic_dimensions, state->expressions); if (kIsLayoutSensitive) { h = H::combine(std::move(h), state->layout); } @@ -532,7 +571,11 @@ class Shape { return Shape::Hash(std::move(h), s); } + int64_t outer_multiplier() const { return outer_multiplier_; } + void set_outer_multiplier(int64_t m) { outer_multiplier_ = m; } private: + int64_t outer_multiplier_ = -1; + 
friend absl::Status ValidateNonLayoutProperties(const Shape& shape); // Define one state struct for each shape category. Depending on the element @@ -560,6 +603,8 @@ class Shape { // respective dimension is dynamically sized. absl::InlinedVector dynamic_dimensions; + absl::InlinedVector expressions; + // The layout of the shape. std::optional layout; }; @@ -583,7 +628,8 @@ class Shape { // Instead, we rely on validation down the road to catch invalid shapes. // This is useful for code that should not crash, such as constructing a // Shape from an unvalidated proto. - void UnsafeAddDimension(int64_t value, bool is_dynamic); + void UnsafeAddDimension(int64_t value, bool is_dynamic, + DynExpr* exp = nullptr); // Convenience accessors for the state_ variant. Each if_*_state() accessor // returns a pointer to the corresponding state struct, or nullptr if the diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h new file mode 100644 index 00000000000000..5c1f5645c25e0e --- /dev/null +++ b/third_party/xla/xla/shape_dynexpr.h @@ -0,0 +1,379 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SHAPE_DYNEXPR_H_ +#define XLA_SHAPE_DYNEXPR_H_ + +#include +#include +#include + +#include "xla/printer.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class DynExpr { + public: + virtual ~DynExpr() = default; + virtual void print(xla::Printer* printer) const = 0; + virtual void to_proto(xla::ExpressionProto* proto) const = 0; + virtual bool is_constant() const = 0; + virtual int64_t get_val() const { return -1; } + virtual DynExpr* s() = 0; // simplify + virtual DynExpr* substitute(int id, DynExpr* v) = 0; + virtual std::set get_all_ids() = 0; + virtual int64_t solve(int64_t x) = 0; + + bool is_dynamic() { return !is_constant(); } + + static DynExpr* zero; + static DynExpr* one; + static DynExpr* _(int64_t val); + static DynExpr* V(int var_id); + static DynExpr* _s(DynExpr* expr); + static bool equal(DynExpr* expr1, DynExpr* expr2); + + friend std::ostream& operator<<(std::ostream& os, DynExpr* expr); +}; + +// constant i +class Constant : public DynExpr { + int64_t value; + + public: + explicit Constant(int64_t v) : value(v) {} + void print(xla::Printer* printer) const override { + if (value < 0) { + printer->Append("("); + } + printer->Append(value); + if (value < 0) { + printer->Append(")"); + } + } + void to_proto(xla::ExpressionProto* proto) const override { + proto->set_constant_value(value); + } + bool is_constant() const override { return true; } + int64_t get_val() const override { return value; } + DynExpr* substitute(int id, DynExpr* v) { return this; } + std::set get_all_ids() { return {}; } + int64_t solve(int64_t x) { return -1; } + DynExpr* s() override; +}; + +// var id (int) +class Variable : public DynExpr { + int id; + + public: + explicit Variable(int identifier) : id(identifier) {} + void print(xla::Printer* printer) const override { + // printer->Append("(Var "); + char letter = 'A' + (id - 1); + printer->Append(std::string(1, letter)); + 
// printer->Append(")"); + } + void to_proto(xla::ExpressionProto* proto) const override { + proto->set_variable_id(id); + } + bool is_constant() const override { return false; } + int get_id() const { return id; } + DynExpr* substitute(int id, DynExpr* v) { return get_id() == id ? v : this;} + std::set get_all_ids() { return {get_id()}; } + int64_t solve(int64_t x) { return x; } + DynExpr* s() override; +}; + +// exp = exp + exp +class Add : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Add(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" + "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* add_msg = proto->mutable_add_node(); + lhs->to_proto(add_msg->mutable_lhs()); + rhs->to_proto(add_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() + rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Add(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... 
+ if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A + c) = x <=> A = x - c => solve A = y with y = x - c + return lhs->solve(x - rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1) { + // (c + A) = x <=> A = x - c => solve A = y with y = x - c + return rhs->solve(x - lhs->get_val()); + } + // No solution + return -1; + } + + DynExpr* s() override; + + ~Add() { + delete lhs; + delete rhs; + } +}; + +// exp = exp - exp +class Sub : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Sub(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" - "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* sub_msg = proto->mutable_sub_node(); + lhs->to_proto(sub_msg->mutable_lhs()); + rhs->to_proto(sub_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() - rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Sub(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... 
+ if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A - c) = x <=> A = x + c => solve A = y with y = x + c + return lhs->solve(x + rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1) { + // (c - A) = x <=> A = c - x => solve A = y with y = c - x (NOTE(review): code below computes y = x + c, copy-pasted from Add — verify; should likely be lhs->get_val() - x) + return rhs->solve(x + lhs->get_val()); + } + // No solution + return -1; + } + + DynExpr* s() override; + + ~Sub() { + delete lhs; + delete rhs; + } +}; + +// exp = exp * exp +class Mul : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Mul(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" * "); + rhs->print(printer); + printer->Append(")"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* mul_msg = proto->mutable_mul_node(); + lhs->to_proto(mul_msg->mutable_lhs()); + rhs->to_proto(mul_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() * rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Mul(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic...
+ if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A * c) = x <=> A = x / c => solve A = y with y = x / c + int64_t c = rhs->get_val(); + if (x % c != 0) return -1; + return lhs->solve(x / c); + } + if (rhs->get_all_ids().size() == 1) { + // (c * A) = x <=> A = x / c => solve A = y with y = x / c + int64_t c = lhs->get_val(); + if (x % c != 0) return -1; + return rhs->solve(x / c); + } + // No solution + return -1; + } + + DynExpr* s() override; + + ~Mul() { + delete lhs; + delete rhs; + } +}; + +// expr / expr +class Div : public DynExpr { + DynExpr* lhs; + DynExpr* rhs; + + public: + Div(DynExpr* l, DynExpr* r) : lhs(std::move(l)), rhs(std::move(r)) {} + void print(xla::Printer* printer) const override { + printer->Append("("); + lhs->print(printer); + printer->Append(" / ( "); + rhs->print(printer); + printer->Append(") )"); + } + + DynExpr* get_lhs() const { return lhs; } + DynExpr* get_rhs() const { return rhs; } + + void to_proto(xla::ExpressionProto* proto) const override { + auto* div_msg = proto->mutable_div_node(); + lhs->to_proto(div_msg->mutable_lhs()); + rhs->to_proto(div_msg->mutable_rhs()); + } + + bool is_constant() const override { + return lhs->is_constant() && rhs->is_constant(); + } + + int64_t get_val() const override { return lhs->get_val() / rhs->get_val(); } + + DynExpr* substitute(int id, DynExpr* v) { + return new Div(lhs->substitute(id, v), rhs->substitute(id, v)); + } + + DynExpr* s() override; + + std::set get_all_ids() { + auto s = lhs->get_all_ids(); + s.merge(rhs->get_all_ids()); + return s; + } + + int64_t solve(int64_t x) { + // Cannot solve if both lhs and rhs are dynamic... 
+ if (lhs->is_dynamic() && rhs->is_dynamic()) return -1; + if (lhs->get_all_ids().size() == 1) { + // (A / c) = x <=> A = x * c => solve A = y with y = x * c + return lhs->solve(x * rhs->get_val()); + } + if (rhs->get_all_ids().size() == 1) { + // (c / A) = x <=> A = c / x => solve A = y with y = c / x + int64_t c = lhs->get_val(); + if (c % x != 0) return -1; + return rhs->solve(c / x); + } + // No solution + return -1; + } + + ~Div() { + delete lhs; + delete rhs; + } +}; + +DynExpr* operator*(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator*(int64_t k, DynExpr& rhs); +DynExpr* operator/(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator/(DynExpr& lhs, int64_t d); +DynExpr* operator+(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator+(DynExpr& lhs, int64_t d); +DynExpr* operator-(DynExpr& lhs, DynExpr& rhs); +DynExpr* operator-(DynExpr& lhs, int64_t d); +bool operator==(DynExpr& lhs, DynExpr& rhs); +bool operator==(DynExpr& lhs, int64_t d); + +inline DynExpr* DynExpr::_(int64_t val) { + if (val == 0) return DynExpr::zero; + if (val == 1) return DynExpr::one; + return new Constant(val); +} +inline DynExpr* DynExpr::V(int var_id) { return new Variable(var_id); } + +} // namespace xla + +#endif // XLA_SHAPE_DYNEXPR_H_ diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 4f92ce19adb1f9..2a72f66d943dfd 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -267,9 +267,20 @@ static std::vector MakeDynamicDimensions( return dynamic_dimensions; } +static std::vector MakeExpressions( + absl::Span dimensions) { + std::vector expressions; + expressions.reserve(dimensions.size()); + for (int64_t d : dimensions) { + expressions.push_back(DynExpr::_(d)); + } + return expressions; +} + /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, - absl::Span dimensions) { - return MakeValidatedShape(element_type, dimensions).value(); + absl::Span dimensions, + absl::Span expressions) { + return 
MakeValidatedShape(element_type, dimensions, expressions).value(); } /* static */ Shape ShapeUtil::MakeScalarShape(PrimitiveType element_type) { @@ -278,8 +289,10 @@ static std::vector MakeDynamicDimensions( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions) { - return MakeValidatedShape(element_type, dimensions, dynamic_dimensions) + const std::vector& dynamic_dimensions, + absl::Span expressions) { + return MakeValidatedShape(element_type, dimensions, dynamic_dimensions, + expressions) .value(); } @@ -296,20 +309,27 @@ static std::vector MakeDynamicDimensions( } /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( - PrimitiveType element_type, absl::Span dimensions) { - return MakeValidatedShape(element_type, dimensions, - MakeDynamicDimensions(dimensions)); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions) { + return MakeValidatedShape( + element_type, dimensions, MakeDynamicDimensions(dimensions), + expressions.empty() ? MakeExpressions(dimensions) : expressions); } /* static */ absl::StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions) { + const std::vector& dynamic_dimensions, + absl::Span expressions) { if (dynamic_dimensions.size() != dimensions.size()) { return InvalidArgument( "dynamic dimensions size %d did not match number of dimensions %d", dynamic_dimensions.size(), dimensions.size()); } - + if (expressions.size() != dimensions.size()) { + return InvalidArgument( + "expressions size %d did not match number of dimensions %d", + expressions.size(), dimensions.size()); + } Shape shape; int64_t dense_shape_size = primitive_util::IsArrayType(element_type) ? 
primitive_util::ByteWidth(element_type) @@ -328,6 +348,7 @@ static std::vector MakeDynamicDimensions( for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; const bool is_dynamic = dynamic_dimensions[i]; + DynExpr* expression = expressions[i]; if (!Shape::IsValidDimensionSize(d, is_dynamic)) { return InvalidArgument("Invalid dimension size %d, is_dynamic=%s", d, is_dynamic ? "true" : "false"); @@ -339,7 +360,7 @@ static std::vector MakeDynamicDimensions( any_overflows |= overflow; } - shape.add_dimensions(d, is_dynamic); + shape.add_dimensions(d, is_dynamic, expression); minor_to_major->push_back(ndims - 1 - i); } @@ -408,10 +429,14 @@ static std::vector MakeDynamicDimensions( } /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType element_type, absl::Span dimensions) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); - return MakeShapeWithDenseLayout(element_type, dimensions, layout); + auto shape = MakeShapeWithDenseLayout(element_type, dimensions, layout); + std::vector exprs(expressions.begin(), expressions.end()); + shape.set_expressions(exprs); + return shape; } /* static */ Shape @@ -442,6 +467,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( dim = LayoutUtil::Major(shape.layout(), dim); } new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(dim)); + new_shape.set_expression(i, shape.expressions(dim)); } new_shape.mutable_layout()->set_memory_space(shape.layout().memory_space()); return new_shape; @@ -731,7 +757,27 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { printer->Append("?"); } } else { + // Only print constant expression if it is different than the dimension + // (i.e. it is wrong!) 
+ DynExpr* expr = shape.expressions(i); + bool is_wrong = expr != nullptr && expr->is_constant() && + expr->get_val() != shape.dimensions(i); printer->Append(shape.dimensions(i)); + if (is_wrong) { + xla::StringPrinter expr_printer; + expr->print(&expr_printer); + LOG(ERROR) << "Mismatched static shape expression at dim " << i + << ": dim=" << shape.dimensions(i) + << ", expr=" << std::move(expr_printer).ToString(); + printer->Append("print(printer); + printer->Append("!>"); + } + if (expr != nullptr && expr->is_dynamic()) { + printer->Append("<"); + expr->print(printer); + printer->Append(">"); + } } }; print_dimension(0); @@ -753,6 +799,11 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return; } PrintHumanString(printer, shape); + if (shape.outer_multiplier() > 0) { + printer->Append("(bm="); + printer->Append(shape.outer_multiplier()); + printer->Append(")"); + } if (!shape.IsArray()) return; if (!shape.has_layout()) return; if (IsScalar(shape)) { @@ -811,6 +862,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { const Shape& rhs) { if (!SameRank(lhs, rhs)) return false; for (int i = 0; i < lhs.dimensions().size(); ++i) { + if (i == 0 && (lhs.outer_multiplier() > 0 || rhs.outer_multiplier() > 0)) { + VLOG(3) << "CompareShapes: batch dimension found. 
Forcely compatible"; + continue; + } if (!lhs.is_unbounded_dynamic_dimension(i) && !rhs.is_unbounded_dynamic_dimension(i) && lhs.dimensions(i) != rhs.dimensions(i)) { @@ -826,7 +881,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs); + return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout().IgnoreBatch()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -871,6 +926,11 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); } +/* static */ DynExpr* ShapeUtil::GetExpression(const Shape& shape, + int64_t dimension_number) { + return shape.expressions(GetDimensionNumber(shape, dimension_number)); +} + /* static */ int64_t ShapeUtil::GetDimensionNumber(const Shape& shape, int64_t dimension_number) { if (dimension_number < 0) { @@ -1207,8 +1267,11 @@ ShapeUtil::PackedFactorFor1DInterleavedArray(const Shape& shape) { const auto permuted_dims = Permute(shape.dimensions(), permutation); const auto permuted_dynamic_dims = Permute(shape.dynamic_dimensions(), permutation); + const auto permuted_expressions = + Permute(shape.expressions(), permutation); for (int i = 0; i < permuted_dims.size(); ++i) { - new_shape.add_dimensions(permuted_dims[i], permuted_dynamic_dims[i]); + new_shape.add_dimensions(permuted_dims[i], permuted_dynamic_dims[i], + permuted_expressions[i]); } // If `shape` has a layout, by contract we choose a new layout such that the diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index 4e3602671fa153..7e9ed98719ff4f 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -330,6 +330,10 @@ class ShapeUtil { // GetDimensionNumber(dimension_number). 
static int64_t GetDimension(const Shape& shape, int64_t dimension_number); + // Extracts the shape's expressions at dimension number + // GetDimensionNumber(dimension_number). + static DynExpr* GetExpression(const Shape& shape, int64_t dimension_number); + // Resolves a dimension number, supporting negative indexing. // // Negative indexing has similar semantics to Python. For an N-dimensional @@ -406,7 +410,8 @@ class ShapeUtil { // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, - absl::Span dimensions); + absl::Span dimensions, + absl::Span expressions = {}); // Make a scalar shape with given primitive type. static Shape MakeScalarShape(PrimitiveType element_type); @@ -419,7 +424,8 @@ class ShapeUtil { // the same size. static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions); + const std::vector& dynamic_dimensions, + absl::Span expressions); // Constructs a new buffer shape with the given element type, and sequence of // dimensions. static Shape MakeBufferShape(PrimitiveType element_type, @@ -430,10 +436,13 @@ class ShapeUtil { // size fits in std::numeric_limits::max(), and dynamic size is not // marked static. static absl::StatusOr MakeValidatedShape( - PrimitiveType element_type, absl::Span dimensions); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions = {}); + static absl::StatusOr MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions, - const std::vector& dynamic_dimensions); + const std::vector& dynamic_dimensions, + absl::Span expressions = {}); // Creates a Shape with element type corresponding to T and the given // dimensions @@ -476,7 +485,8 @@ class ShapeUtil { // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). 
static Shape MakeShapeWithDescendingLayout( - PrimitiveType element_type, absl::Span dimensions); + PrimitiveType element_type, absl::Span dimensions, + absl::Span expressions = {}); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 0f3ba256573de1..e06db894249632 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -247,6 +247,8 @@ typedef struct XLA_Shape { int element_type; Int64List dimensions; BoolList dynamic_dimensions; + Int64List batch_multipliers; + Int64List batch_offsets; struct XLA_Shape* tuple_shapes; // owned int ntuple_shapes; bool has_layout; diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index ca8ba0553bd56a..23b7b9689ebfa4 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -222,6 +222,8 @@ message DebugOptions { // When true, XLA:CPU uses XNNPACK to execute supported operations. bool xla_cpu_use_xnnpack = 359; + string xla_compile_batch_sizes = 399; + // Enabling this will enable optimizations that ignore the possibility of NaN. bool xla_enable_fast_math = 335; @@ -1208,7 +1210,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 389 + // Next id: 400 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.
diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 98462a4f6eb83c..798dc12fb5734e 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -361,6 +361,8 @@ message ShapeProto { // The layout used to back this shape. LayoutProto layout = 5; + repeated ExpressionProto expressions = 7; + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and Shape::Hash() appropriately to account for the // new field. @@ -1204,3 +1206,34 @@ message OriginalArrayProto { message OriginalValueProto { repeated OriginalArrayProto leaves = 1; } + +message ExpressionProto { + oneof node_type { + int32 constant_value = 1; // cons + int32 variable_id = 2; // var + AddNode add_node = 3; // exp + exp + SubNode sub_node = 4; // exp - exp + MulNode mul_node = 5; // exp * exp + DivNode div_node = 6; // exp / exp + } +} + +message AddNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message SubNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message MulNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} + +message DivNode { + ExpressionProto lhs = 1; + ExpressionProto rhs = 2; +} \ No newline at end of file