From ab7359b4547da7f05172263f73a18ac552448a85 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Sat, 4 Apr 2026 15:05:16 +0100 Subject: [PATCH] Add symbolic contents propagation --- .../device_compilation_cluster_signature.cc | 52 +++-- .../device_compilation_cluster_signature.h | 8 +- tensorflow/compiler/jit/flags.cc | 4 + tensorflow/compiler/jit/flags.h | 3 + tensorflow/compiler/jit/kernels/xla_ops.cc | 92 ++++++++- .../compiler/jit/mark_for_compilation_pass.cc | 36 +++- .../compiler/jit/partially_decluster_pass.cc | 16 +- tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/graph_compiler.cc | 5 - tensorflow/compiler/tf2xla/kernels/BUILD | 2 + .../compiler/tf2xla/kernels/binary_ops.cc | 125 +++++++----- tensorflow/compiler/tf2xla/kernels/cast_op.cc | 9 + .../compiler/tf2xla/kernels/concat_op.cc | 86 ++++++++- .../compiler/tf2xla/kernels/const_op.cc | 137 +++++++++---- .../compiler/tf2xla/kernels/cwise_ops.cc | 181 ++++++++++++++++++ .../compiler/tf2xla/kernels/cwise_ops.h | 8 + tensorflow/compiler/tf2xla/kernels/fill_op.cc | 11 +- .../compiler/tf2xla/kernels/identity_op.cc | 4 + tensorflow/compiler/tf2xla/kernels/pack_op.cc | 51 ++++- .../compiler/tf2xla/kernels/sequence_ops.cc | 180 +++++++++++++++-- .../compiler/tf2xla/kernels/shape_op.cc | 93 +++++++-- .../compiler/tf2xla/kernels/slice_op.cc | 55 ++++++ .../tf2xla/kernels/strided_slice_op.cc | 81 +++++++- .../compiler/tf2xla/kernels/variable_ops.cc | 19 +- .../compiler/tf2xla/symbolic_content_util.h | 29 +++ tensorflow/compiler/tf2xla/xla_argument.cc | 9 + tensorflow/compiler/tf2xla/xla_argument.h | 8 + tensorflow/compiler/tf2xla/xla_compiler.cc | 20 ++ tensorflow/compiler/tf2xla/xla_expression.cc | 85 +++++--- tensorflow/compiler/tf2xla/xla_expression.h | 15 ++ tensorflow/compiler/tf2xla/xla_op_kernel.cc | 6 + .../core/common_runtime/constant_folding.cc | 89 ++++++++- tensorflow/core/framework/BUILD | 1 + tensorflow/core/framework/tensor_shape.cc | 5 +- .../xla/xla/hlo/builder/xla_builder.cc 
| 56 +++++- third_party/xla/xla/hlo/builder/xla_builder.h | 8 + third_party/xla/xla/hlo/ir/hlo_instruction.cc | 68 +++++++ third_party/xla/xla/hlo/ir/hlo_instruction.h | 16 ++ .../simplifiers/hlo_constant_folding.cc | 22 ++- .../xla/service/dynamic_constant_rewriter.cc | 65 +++---- third_party/xla/xla/service/hlo.proto | 5 +- third_party/xla/xla/shape.cc | 4 + third_party/xla/xla/shape_dynexpr.h | 6 + 43 files changed, 1541 insertions(+), 235 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/symbolic_content_util.h diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc index 81902a28532dbf..5996c5f0b3a70c 100644 --- a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc +++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/device_compilation_cluster_signature.h" +#include "absl/strings/str_cat.h" #include #include #include @@ -22,14 +23,23 @@ limitations under the License. namespace tensorflow { namespace { using Signature = DeviceCompilationClusterSignature; +using ConstantTensor = Signature::ConstantTensor; using TensorTypeAndShape = Signature::TensorTypeAndShape; // Functor that converts a Signature's arg to a human readable string. 
struct SignatureHumanStringAppender { explicit SignatureHumanStringAppender(std::string* dest) : dest(dest) {} std::string* dest; - void operator()(const Tensor& arg) { - absl::StrAppend(dest, "; ", arg.DebugString()); + void operator()(const ConstantTensor& arg) { + absl::StrAppend(dest, "; ", arg.value.DebugString()); + if (!arg.contents.empty()) { + absl::StrAppend(dest, " contents=["); + for (int i = 0; i < arg.contents.size(); ++i) { + if (i > 0) absl::StrAppend(dest, ","); + absl::StrAppend(dest, arg.contents[i].DebugString()); + } + absl::StrAppend(dest, "]"); + } } void operator()(const TensorTypeAndShape& arg) { absl::StrAppend(dest, ",", DataTypeString(arg.first)); @@ -40,18 +50,29 @@ struct SignatureHumanStringAppender { // Functor that compares the arg values of two different signatures. Returns // true when the args are not equal. struct SignatureNotEqual { - bool operator()(const Tensor& arg, const Tensor& other) { - return arg.dtype() != other.dtype() || arg.shape() != other.shape() || - arg.tensor_data() != other.tensor_data(); + bool operator()(const ConstantTensor& arg, const ConstantTensor& other) { + if (arg.value.dtype() != other.value.dtype() || + arg.value.shape() != other.value.shape() || + arg.value.tensor_data() != other.value.tensor_data() || + arg.contents.size() != other.contents.size()) { + return true; + } + for (int i = 0; i < arg.contents.size(); ++i) { + if (arg.contents[i].SerializeAsString() != + other.contents[i].SerializeAsString()) { + return true; + } + } + return false; } bool operator()(const TensorTypeAndShape& arg, const TensorTypeAndShape& other) { return arg.first != other.first || arg.second != other.second; } - bool operator()(const Tensor& arg, const TensorTypeAndShape& other) { + bool operator()(const ConstantTensor& arg, const TensorTypeAndShape& other) { return true; } - bool operator()(const TensorTypeAndShape& arg, const Tensor& other) { + bool operator()(const TensorTypeAndShape& arg, const ConstantTensor& 
other) { return true; } }; @@ -61,12 +82,16 @@ struct SignatureNotEqual { struct SignatureHashCombiner { explicit SignatureHashCombiner(const uint64 h) : h(h) {} uint64 h; - uint64 operator()(const Tensor& arg) { - h = Hash64Combine(h, std::hash()(static_cast(arg.dtype()))); + uint64 operator()(const ConstantTensor& arg) { + h = Hash64Combine(h, std::hash()(static_cast(arg.value.dtype()))); h = Hash64Combine( - h, Hash64(arg.tensor_data().data(), arg.tensor_data().size())); - for (int dim = 0; dim < arg.dims(); ++dim) { - h = Hash64Combine(h, std::hash()(arg.dim_size(dim))); + h, Hash64(arg.value.tensor_data().data(), arg.value.tensor_data().size())); + for (int dim = 0; dim < arg.value.dims(); ++dim) { + h = Hash64Combine(h, std::hash()(arg.value.dim_size(dim))); + } + for (const xla::ExpressionProto& expr : arg.contents) { + std::string serialized = expr.SerializeAsString(); + h = Hash64Combine(h, Hash64(serialized.data(), serialized.size())); } return h; } @@ -120,7 +145,8 @@ absl::StatusOr Signature::Build( switch (arg.kind) { case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kConstantResource: - signature.args.push_back(arg.constant_value); + signature.args.push_back( + ConstantTensor{arg.constant_value, arg.constant_value_expressions}); break; case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kResource: diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.h b/tensorflow/compiler/jit/device_compilation_cluster_signature.h index 4acea2a03c2cb4..da2de8de370842 100644 --- a/tensorflow/compiler/jit/device_compilation_cluster_signature.h +++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.h @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" namespace tensorflow { @@ -34,7 +36,11 @@ struct DeviceCompilationClusterSignature { // argument number. 
Tensors must be in host memory. using TensorTypeAndShape = std::pair>; - absl::InlinedVector, 8> args; + struct ConstantTensor { + Tensor value; + std::vector contents; + }; + absl::InlinedVector, 8> args; bool operator==(const DeviceCompilationClusterSignature& other) const; diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 89ee40b0f58e9a..5a6f741a01e972 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -167,6 +167,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_enable_dynamic_sizes", &mark_for_compilation_flags->tf_xla_enable_dynamic_sizes, "Enable dynamic sizes support."), + Flag("tf_xla_enable_symbolic_content", + &mark_for_compilation_flags->tf_xla_enable_symbolic_content, + "Enable symbolic content propagation."), Flag("tf_xla_persistent_cache_directory", &mark_for_compilation_flags->tf_xla_persistent_cache_directory, "If non-empty, JIT-compiled executables are saved to and loaded " @@ -262,6 +265,7 @@ void AllocateAndParseFlags() { ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false; mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false; mark_for_compilation_flags->tf_xla_enable_dynamic_sizes = false; + mark_for_compilation_flags->tf_xla_enable_symbolic_content = false; mark_for_compilation_flags->tf_xla_persistent_cache_directory = ""; mark_for_compilation_flags->tf_xla_persistent_cache_device_types = ""; mark_for_compilation_flags->tf_xla_persistent_cache_read_only = false; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 971dd8a7a38229..480e20dc7ec296 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -102,6 +102,9 @@ struct MarkForCompilationPassFlags { // If true enables support of dynamic sizes. bool tf_xla_enable_dynamic_sizes; + // If true enables symbolic content propagation. 
+ bool tf_xla_enable_symbolic_content; + // If non-empty, JIT-compiled executables are saved to and loaded from the // specified file system directory path. std::string tf_xla_persistent_cache_directory; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 782ac342fc8991..aa28b5ac899fee 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/printer.h" #include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/shape_dynexpr.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/allocator.h" @@ -384,7 +385,6 @@ GetXlaCompilerArgsAndSnapshotVariables( return result; } - std::unique_ptr ExprFromProto(const ExpressionProto& proto) { switch (proto.node_type_case()) { case ExpressionProto::kConstantValue: @@ -448,7 +448,6 @@ static xla::DExpr DimExprToDExpr(const DimExpr* e) { return xla::DExpr::Unknown(); } - absl::Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, @@ -511,6 +510,71 @@ absl::Status CompileToLocalExecutable( XlaBatchMatcher* xla_batch_matcher = xla_device_compiler->xla_batch_matcher(); std::optional dynamic_dim_expr; + auto maybe_attach_shape_contents_from_attrs = + [&](int arg_index, const auto& attr_map, + const std::string& node_name) { + auto& arg = norm_args[arg_index]; + if (arg.kind != XlaCompiler::Argument::kConstant) { + return; + } + + bool has_dynamic = false; + auto has_dynamic_it = attr_map.find("has_dynamic"); + if (has_dynamic_it == attr_map.end()) { + return; + } + has_dynamic = has_dynamic_it->second.b(); + if (!has_dynamic) { + return; + } + + auto inferred_shape_it = attr_map.find("user_inferred_shape"); + if (inferred_shape_it == attr_map.end()) { + LOG(INFO) 
<< "XlaCompileOp saw has_dynamic for const arg " + << arg_index << " node=" << node_name + << " but no user_inferred_shape attr"; + return; + } + + TensorShapeProto inferred_shape_proto; + inferred_shape_proto = inferred_shape_it->second.shape(); + + TensorShape inferred_shape(inferred_shape_proto); + if (!TensorShapeUtils::IsVector(arg.constant_value.shape()) || + arg.constant_value.NumElements() != inferred_shape.dims()) { + LOG(INFO) << "XlaCompileOp const arg " << arg_index + << " node=" << node_name + << " has dynamic shape metadata but tensor shape " + << arg.constant_value.shape().DebugString() + << " does not match inferred rank " << inferred_shape.dims(); + return; + } + + arg.constant_value_expressions.clear(); + arg.constant_value_expressions.reserve(inferred_shape.dims()); + for (int64_t i = 0; i < inferred_shape.dims(); ++i) { + xla::ExpressionProto expr; + const xla::DExpr& dim_expr = inferred_shape.get_expression(i); + if (dim_expr && dim_expr->is_dynamic()) { + dim_expr->to_proto(&expr); + } else if (arg.constant_value.dtype() == DT_INT32) { + expr.set_constant_value(arg.constant_value.flat()(i)); + } else if (arg.constant_value.dtype() == DT_INT64) { + expr.set_constant_value(arg.constant_value.flat()(i)); + } else { + LOG(INFO) << "XlaCompileOp const arg " << arg_index + << " node=" << node_name + << " has unsupported dtype for inferred shape contents: " + << DataTypeString(arg.constant_value.dtype()); + arg.constant_value_expressions.clear(); + return; + } + arg.constant_value_expressions.push_back(std::move(expr)); + } + LOG(INFO) << "XlaCompileOp recovered " << arg.constant_value_expressions.size() + << " constant_value_expressions for const arg " << arg_index + << " node=" << node_name << " from user_inferred_shape"; + }; auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) { if (!saw_dynamic_dim_value) { saw_dynamic_dim_value = true; @@ -536,6 +600,7 @@ absl::Status CompileToLocalExecutable( VLOG(1) << "XlaCompileOp 
retrieved shape-derived marker for arg " << arg_index << " node=" << node_name; } + maybe_attach_shape_contents_from_attrs(arg_index, attr_map, node_name); // Special case for _dynamic_dim... auto dyn_dim_attr = attr_map.find("_dynamic_dim"); @@ -632,6 +697,21 @@ absl::Status CompileToLocalExecutable( return; } + auto set_constant_contents = [&](int rewrite_index) { + arg.constant_value_expressions.clear(); + const int64_t num_elements = arg.constant_value.NumElements(); + arg.constant_value_expressions.reserve(num_elements); + for (int64_t i = 0; i < num_elements; ++i) { + xla::ExpressionProto expr; + if (i == rewrite_index) { + dynamic_dim_expr->to_proto(&expr); + } else { + expr.set_constant_value(arg.constant_value.flat()(i)); + } + arg.constant_value_expressions.push_back(std::move(expr)); + } + }; + if (arg.constant_value.dtype() == DT_INT32) { auto flat = arg.constant_value.flat(); int rewrite_index = -1; @@ -650,9 +730,8 @@ absl::Status CompileToLocalExecutable( VLOG(1) << "XlaCompileOp int32 constant arg " << arg_index << " index " << rewrite_index << " matches dynamic_dim_value=" << dynamic_dim_value; - arg.dynamic_constant_index = rewrite_index; - arg.dynamic_constant_expr = dynamic_dim_expr; mutable_flat(rewrite_index) = filled_batch; + set_constant_contents.template operator()(rewrite_index); } } else if (arg.constant_value.dtype() == DT_INT64) { auto flat = arg.constant_value.flat(); @@ -670,9 +749,8 @@ absl::Status CompileToLocalExecutable( VLOG(1) << "XlaCompileOp int64 constant arg " << arg_index << " index " << rewrite_index << " matches dynamic_dim_value=" << dynamic_dim_value; - arg.dynamic_constant_index = rewrite_index; - arg.dynamic_constant_expr = dynamic_dim_expr; mutable_flat(rewrite_index) = filled_batch; + set_constant_contents.template operator()(rewrite_index); } } }; @@ -687,7 +765,7 @@ absl::Status CompileToLocalExecutable( TensorShape& shp = std::get(norm_args[i].shape); for (int j = 0; j < shp.get_expressions().size(); ++j) { auto 
e = shp.get_expression(j); - if (e->is_dynamic()) { + if (e && e->is_dynamic()) { int64_t old = shp.dim_size(j); old_vars.push_back({i, j, old}); xla::DExpr padded_expr = xla::DExpr::Const(filled_batch); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 54d0b51c78bc44..566bab23a11867 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -923,15 +923,45 @@ absl::StatusOr MarkForCompilationPassImpl::Initialize() { if (debug_options_.enable_dynamic_sizes) { LogExpressionsViaGraphProperties(*graph_); TF_RETURN_IF_ERROR(AssignDimVars()); + auto has_dynamic_input_expression = [&](const Node* n) { + for (const Edge* edge : n->in_edges()) { + if (edge->IsControlEdge()) { + continue; + } + const Node* src = edge->src(); + auto it = expr_map.find(src->name()); + if (it == expr_map.end()) { + continue; + } + const int output_index = edge->src_output(); + if (output_index < 0 || output_index >= it->second.size()) { + continue; + } + for (const auto& expr_ptr : it->second[output_index]) { + if (expr_ptr == nullptr) { + continue; + } + xla::DExpr dyn = DimExprToDExpr(expr_ptr.get()); + if (dyn && dyn->is_dynamic()) { + return true; + } + } + } + return false; + }; for (Node* n : graph_->op_nodes()) { bool mark_shape_derived = false; if (n->type_string() == "Shape" || n->type_string() == "ShapeN") { - mark_shape_derived = true; + mark_shape_derived = has_dynamic_input_expression(n); } else if (n->type_string() == "Cast") { for (const Edge* edge : n->in_edges()) { - if (edge->IsControlEdge()) continue; + if (edge->IsControlEdge()) { + continue; + } const Node* src = edge->src(); - if (src->type_string() == "Shape" || src->type_string() == "ShapeN") { + if ((src->type_string() == "Shape" || + src->type_string() == "ShapeN") && + has_dynamic_input_expression(src)) { mark_shape_derived = true; break; } diff --git 
a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 442635c9a29696..ac6e0ee0177242 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -19,8 +19,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/device_util.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" @@ -47,6 +49,10 @@ absl::Status FindNodesToDecluster(const Graph& graph, MemoryTypeVector input_mtypes, output_mtypes; for (Node* n : post_order) { + if (SymbolicContentEnabled() && + n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) { + continue; + } std::optional from_cluster = GetXlaClusterForNode(*n); if (!from_cluster) { continue; @@ -308,6 +314,10 @@ absl::Status PartiallyDeclusterGraph(Graph* graph, if (!compile_time_const_nodes[n->id()]) { continue; } + if (SymbolicContentEnabled() && + n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) { + continue; + } absl::string_view cluster_name = *GetXlaClusterForNode(*n); bool node_on_cluster_edge = @@ -379,6 +389,10 @@ absl::Status PartiallyDeclusterGraph(Graph* graph) { if (!IsShapeConsumerOp(*n)) { continue; } + if (SymbolicContentEnabled() && + n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) { + continue; + } std::optional cluster = GetXlaClusterForNode(*n); if (!cluster.has_value()) { @@ -393,8 +407,6 @@ absl::Status PartiallyDeclusterGraph(Graph* graph) { continue; } - VLOG(2) << "Declustering " << n->name() - << " because it is a root shape consumer"; 
RemoveFromXlaCluster(n); } return absl::OkStatus(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0a800512aa3cb7..c571f73beff0be 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -876,6 +876,7 @@ cc_library( hdrs = [ "literal_util.h", "shape_util.h", + "symbolic_content_util.h", "type_util.h", ], visibility = [":friends"], diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 9d6a1752a34488..f23c423fbb2632 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -83,11 +83,6 @@ absl::Status PrepareArguments( case XlaExpression::Kind::kConstant: arg.kind = XlaCompiler::Argument::kConstant; arg.constant_value = *expressions[i]->constant_value(); - if (expressions[i]->dynamic_constant_index().has_value()) { - arg.dynamic_constant_index = - *expressions[i]->dynamic_constant_index(); - arg.dynamic_constant_expr = expressions[i]->dynamic_constant_expr(); - } break; case XlaExpression::Kind::kXlaOp: if (arg_must_be_compile_time_constant[i]) { diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 25286d7bea51a8..37d68dbd23c242 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1504,6 +1504,7 @@ tf_kernel_library( name = "slice_op", srcs = ["slice_op.cc"], deps = [ + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -2211,6 +2212,7 @@ tf_kernel_library( deps = [ ":shape_util", ":tensor_list_utils", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc 
b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index e9f571d830d619..374950c7110f1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -42,7 +42,7 @@ namespace { // A subclass of a XlaBinaryOp must build the computation that // describes the (tensor,tensor)->tensor function to apply to each element of // the input. -#define XLA_MAKE_BINARY(NAME, HLO) \ +#define XLA_MAKE_BINARY(NAME, HLO, SYMBOLIC_HLO) \ class NAME##Op : public XlaBinaryOp { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ @@ -59,17 +59,26 @@ namespace { (void)extend_dimensions; \ return HLO; \ } \ + xla::DExpr SymbolicComputation(const xla::DExpr& lhs, \ + const xla::DExpr& rhs) override { \ + return SYMBOLIC_HLO; \ + } \ }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) -XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(AddV2, xla::Add(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions), + (lhs + rhs).simplify()); +XLA_MAKE_BINARY(AddV2, xla::Add(lhs, rhs, extend_dimensions), + (lhs + rhs).simplify()); +XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions), + (lhs - rhs).simplify()); +XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions), + (lhs * rhs).simplify()); +XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions), + (lhs / rhs).simplify()); -XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions), xla::DExpr()); +XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions), xla::DExpr()); // Implementation of DivNoNan. 
Pseudo-code: // if (y == 0) { @@ -87,7 +96,8 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, return result; } XLA_MAKE_BINARY(DivNoNan, - DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper), + xla::DExpr()); // Implementation of MulNoNan. Pseudo-code: // if (y == 0) { @@ -105,7 +115,8 @@ static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, return result; } XLA_MAKE_BINARY(MulNoNan, - MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper), + xla::DExpr()); // Implementation of FloorDiv. // @@ -144,7 +155,8 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, return xla::Select(round_down, xla::Sub(x_div_y, one), x_div_y); } XLA_MAKE_BINARY(FloorDiv, - FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper), + (*lhs / *rhs)->s()); xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -153,7 +165,7 @@ xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y, auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); } -XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper), nullptr); xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -163,7 +175,7 @@ xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y, auto x_is_zero = xla::Eq(x, zero); return xla::Select(x_is_zero, zero, non_zero); } -XLA_MAKE_BINARY(Xlog1py, Xlog1pyImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xlog1py, Xlog1pyImpl(lhs, rhs, broadcast_helper), nullptr); xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -172,7 +184,7 @@ xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y, auto is_zero = xla::Eq(x, zero); return 
xla::Select(is_zero, zero, xla::Div(x, y)); } -XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper), nullptr); // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); @@ -189,34 +201,41 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, return xla::Select(do_plus, xla::Add(trunc_mod, y), trunc_mod); } XLA_MAKE_BINARY(FloorMod, - FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper), + nullptr); -XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions), nullptr); -XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions), + nullptr); XLA_MAKE_BINARY(RightShift, (DataTypeIsUnsigned(ctx->input_type(0)) ? 
xla::ShiftRightLogical(lhs, rhs, extend_dimensions) - : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions))); - -XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); + : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)), + nullptr); + +XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions), + (*lhs / *rhs)->s()); +XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))), + nullptr); XLA_MAKE_BINARY( RsqrtGrad, xla::Mul((lhs * lhs) * lhs, xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), - extend_dimensions)); + extend_dimensions), + nullptr); XLA_MAKE_BINARY( SqrtGrad, xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - lhs, extend_dimensions)); + lhs, extend_dimensions), + nullptr); // Implementation of TruncateDiv. 
// @@ -235,35 +254,39 @@ static xla::XlaOp TruncateDivImpl(xla::XlaBuilder* b, DataType dtype, return xla::Select(round_up, xla::Ceil(x_div_y), xla::Floor(x_div_y)); } XLA_MAKE_BINARY(TruncateDiv, - TruncateDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); + TruncateDivImpl(b, input_type(0), lhs, rhs, broadcast_helper), + (*lhs / *rhs)->s()); +XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions), nullptr); // Comparison ops -XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions), nullptr); +XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions), nullptr); // Non-linear ops XLA_MAKE_BINARY(SigmoidGrad, xla::Mul(xla::Mul(rhs, lhs), - xla::Sub(XlaHelpers::One(b, input_type(0)), lhs))); + xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)), + nullptr); -XLA_MAKE_BINARY(SoftplusGrad, xla::Mul(lhs, xla::Logistic(rhs))); +XLA_MAKE_BINARY(SoftplusGrad, xla::Mul(lhs, xla::Logistic(rhs)), nullptr); // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, xla::Div(lhs, xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Abs(rhs))))); + xla::Abs(rhs)))), + nullptr); XLA_MAKE_BINARY(TanhGrad, xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, 
input_type(0)), - xla::Mul(lhs, lhs)))); + xla::Mul(lhs, lhs))), + nullptr); -XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions), nullptr); xla::XlaOp SquaredDifferenceImpl( DataType dtype, xla::XlaOp x, xla::XlaOp y, @@ -277,7 +300,8 @@ xla::XlaOp SquaredDifferenceImpl( } XLA_MAKE_BINARY(SquaredDifference, SquaredDifferenceImpl(input_type(0), lhs, rhs, - extend_dimensions)); + extend_dimensions), + nullptr); xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -285,7 +309,7 @@ xla::XlaOp IgammaImpl(xla::XlaOp x, xla::XlaOp y, return xla::Igamma(x, y); } -XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Igamma, IgammaImpl(lhs, rhs, broadcast_helper), nullptr); xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -293,7 +317,8 @@ xla::XlaOp IgammaGradAImpl(xla::XlaOp x, xla::XlaOp y, return xla::IgammaGradA(x, y); } -XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(IgammaGradA, IgammaGradAImpl(lhs, rhs, broadcast_helper), + nullptr); xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -302,7 +327,7 @@ xla::XlaOp RandomGammaGradImpl(xla::XlaOp x, xla::XlaOp y, } XLA_MAKE_BINARY(RandomGammaGrad, - RandomGammaGradImpl(lhs, rhs, broadcast_helper)); + RandomGammaGradImpl(lhs, rhs, broadcast_helper), nullptr); xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { @@ -310,7 +335,7 @@ xla::XlaOp IgammacImpl(xla::XlaOp x, xla::XlaOp y, return xla::Igammac(x, y); } -XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Igammac, IgammacImpl(lhs, rhs, broadcast_helper), nullptr); xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x, const BCast& broadcast_helper) { @@ -318,14 +343,14 @@ xla::XlaOp PolygammaImpl(xla::XlaOp n, xla::XlaOp x, return xla::Polygamma(n, 
x); } -XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Polygamma, PolygammaImpl(lhs, rhs, broadcast_helper), nullptr); xla::XlaOp ZetaImpl(xla::XlaOp x, xla::XlaOp q, const BCast& broadcast_helper) { std::tie(x, q) = XlaBinaryOp::Broadcast(x, q, broadcast_helper); return xla::Zeta(x, q); } -XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Zeta, ZetaImpl(lhs, rhs, broadcast_helper), nullptr); #undef XLA_MAKE_BINARY diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 1779cfcc1ced40..74e9c5aa4003ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -92,6 +92,15 @@ class CastOp : public XlaOpKernel { output = xla::ConvertElementType(input, dst_type_); } + const auto& input_contents = ctx->InputExpression(0).contents(); + if (!input_contents.empty()) { + auto output_expr = + XlaExpression::XlaOp(output, ctx->expected_output_dtype(0)); + output_expr.set_contents(std::vector(input_contents.begin(), + input_contents.end())); + ctx->SetOutputExpression(0, output_expr); + return; + } ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index bed3479941ca41..823ed90e1d624e 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -18,9 +18,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/log/log.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -41,6 +43,44 @@ limitations under the License. 
namespace tensorflow { namespace { +bool AppendConcatInputContents(const XlaExpression& expr, + const TensorShape& shape, + std::vector* contents) { + const auto& input_contents = expr.contents(); + if (!input_contents.empty()) { + if (shape.dims() == 0) { + if (input_contents.size() != 1) { + return false; + } + contents->push_back(input_contents[0]); + return true; + } + if (shape.dims() == 1) { + contents->insert(contents->end(), input_contents.begin(), + input_contents.end()); + return true; + } + return false; + } + if (shape.dims() == 0) { + contents->push_back(xla::DExpr::Const(xla::kUnknownContentSentinel)); + return true; + } + if (shape.dims() != 1) { + return false; + } + for (int64_t i = 0; i < shape.dim_size(0); ++i) { + contents->push_back(xla::DExpr::Const(xla::kUnknownContentSentinel)); + } + return true; +} + +bool HasDynamicContents(absl::Span contents) { + return absl::c_any_of(contents, [](const xla::DExpr& expr) { + return expr && expr->is_dynamic(); + }); +} + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +114,29 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. 
std::vector input_data; + std::vector> input_contents; + input_contents.resize(N); + bool has_output_contents = SymbolicContentEnabled() && axis == 0; + std::vector output_contents; + if (has_output_contents) { + for (int i = 0; i < N; ++i) { + if (!AppendConcatInputContents(ctx->InputExpression(ValueInputIndex(i)), + shapes[i], + &output_contents)) { + has_output_contents = false; + output_contents.clear(); + break; + } + if (shapes[i].dims() == 0) { + input_contents[i].push_back(output_contents.back()); + } else if (shapes[i].dims() == 1) { + input_contents[i].insert(input_contents[i].end(), + output_contents.end() - shapes[i].dim_size(0), + output_contents.end()); + } + } + has_output_contents = has_output_contents && HasDynamicContents(output_contents); + } int output_concat_dim = 0; for (int i = 0; i < N; ++i) { xla::XlaOp handle = values[i]; @@ -86,7 +149,12 @@ class ConcatBaseOp : public XlaOpKernel { "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. 
- input_data.push_back(xla::Reshape(handle, {1})); + xla::XlaOp reshaped = xla::Reshape(handle, {1}); + if (has_output_contents) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(reshaped, input_contents[i])); + } + input_data.push_back(reshaped); } else { input_data.push_back(handle); } @@ -94,10 +162,24 @@ class ConcatBaseOp : public XlaOpKernel { } VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + auto output = xla::ConcatInDim(ctx->builder(), input_data, axis); + if (has_output_contents) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(output, output_contents)); + auto output_expr = + XlaExpression::XlaOp(output, ctx->expected_output_dtype(0)); + output_expr.set_contents(std::move(output_contents)); + ctx->SetOutputExpression(0, output_expr); + return; + } + ctx->SetOutput(0, output); } private: + int ValueInputIndex(int value_index) const { + return value_index < axis_index_ ? value_index : value_index + 1; + } + int axis_index_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index faa16509eb2cc2..beb478cd979532 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -22,10 +23,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/shape_dynexpr.h" -#include "tsl/platform/protobuf.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" namespace tensorflow { namespace { @@ -102,6 +105,57 @@ xla::XlaOp GetScalarConst(const TensorProto& proto, xla::XlaBuilder* b) { return xla::XlaOp(); } +bool IsDynamicExpressionProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kVariableId: + return true; + case ExpressionProto::kAddNode: + return IsDynamicExpressionProto(proto.add_node().lhs()) || + IsDynamicExpressionProto(proto.add_node().rhs()); + case ExpressionProto::kSubNode: + return IsDynamicExpressionProto(proto.sub_node().lhs()) || + IsDynamicExpressionProto(proto.sub_node().rhs()); + case ExpressionProto::kMulNode: + return IsDynamicExpressionProto(proto.mul_node().lhs()) || + IsDynamicExpressionProto(proto.mul_node().rhs()); + case ExpressionProto::kDivNode: + return IsDynamicExpressionProto(proto.div_node().lhs()) || + IsDynamicExpressionProto(proto.div_node().rhs()); + case ExpressionProto::kConstantValue: + case ExpressionProto::NODE_TYPE_NOT_SET: + return false; + } +} + +std::vector BuildShapeContentsFromTensorShapeProto( + const TensorShapeProto& shape) { + std::vector contents; + contents.reserve(shape.dim_size()); + for (int i = 0; i < shape.dim_size(); ++i) { + xla::DExpr expr = + i < shape.expressions_size() ? tensorflow::DExprFromProto(shape.expressions(i)) + : xla::DExpr(); + LOG(INFO) << "BuildShapeContentsFromTensorShape dim=" << i + << " expr=" << expr + << " dynamic=" + << (expr && expr->is_dynamic() ? "true" : "false"); + contents.push_back(expr && expr->is_dynamic() + ? 
std::move(expr) + : xla::DExpr::Const(xla::kUnknownContentSentinel)); + } + return contents; +} + +int64_t CountDynamicShapeContents(const TensorShapeProto& shape) { + int64_t dynamic_count = 0; + for (int i = 0; i < shape.expressions_size(); ++i) { + if (IsDynamicExpressionProto(shape.expressions(i))) { + ++dynamic_count; + } + } + return dynamic_count; +} + class ConstOp : public XlaOpKernel { public: explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -121,65 +175,66 @@ class ConstOp : public XlaOpKernel { bool has_dynamic = false; TensorShapeProto inferred_shape_proto; - TensorShape shape; if (GetNodeAttr(ctx->op_kernel().def(), "has_dynamic", &has_dynamic).ok() && has_dynamic) { if (GetNodeAttr(ctx->op_kernel().def(), "user_inferred_shape", &inferred_shape_proto) .ok()) { - shape = TensorShape(inferred_shape_proto); + LOG(INFO) << "ConstOp recovered dynamic folded-const metadata with " + << "inferred_shape=" << inferred_shape_proto.DebugString() + << " dynamic_exprs=" + << CountDynamicShapeContents(inferred_shape_proto); } } // To avoid blowups for large constants filled with the same value, // recognize that case and emit a scalar broadcast instead. - shape = has_dynamic ? 
shape : TensorShape(proto_.tensor_shape()); + TensorShape shape(proto_.tensor_shape()); if (shape.num_elements() > 1) { xla::XlaOp value = GetScalarConst(proto_, b); if (value.valid()) { - ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes(), - shape.get_filled_expressions())); + if (has_dynamic) { + LOG(INFO) << "ConstOp broadcast fast path shape=" + << shape.DebugString() << " inferred_rank=" + << inferred_shape_proto.dim_size(); + } + xla::XlaOp broadcast = + xla::Broadcast(value, shape.dim_sizes(), shape.get_expressions()); + XlaExpression output = + XlaExpression::XlaOp(broadcast, ctx->expected_output_dtype(0)); + if (has_dynamic && shape.dims() == 1 && + shape.dim_size(0) == inferred_shape_proto.dim_size()) { + LOG(INFO) << "ConstOp attaching shape contents through broadcast fast " + << "path with " << shape.dim_size(0) + << " entries and dynamic_exprs=" + << CountDynamicShapeContents(inferred_shape_proto); + output.set_contents( + BuildShapeContentsFromTensorShapeProto(inferred_shape_proto)); + } + ctx->SetOutputExpression(0, output); return; } } + Tensor tensor(proto_.dtype()); + OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), + errors::InvalidArgument("Cannot parse tensor from proto: ", + proto_.DebugString())); if (has_dynamic) { - std::vector dimension_constants; - for (int i = 0; i < shape.dims(); ++i) { - if (shape.get_expression(i) && shape.get_expression(i)->is_dynamic()) { - int32_t dim_val = static_cast(shape.dim_size(i)); - xla::XlaOp scalar_const = xla::ConstantR0(b, dim_val); - xla::ExpressionProto expr_proto; - shape.get_expression(i)->to_proto(&expr_proto); - std::string expr_textproto = - tsl::LegacyUnredactedShortDebugString(expr_proto); - VLOG(1) << "ConstOp:expr_textproto is " << expr_textproto; - OP_REQUIRES_OK( - ctx, b->SetInstructionFrontendAttribute(scalar_const, - "dynamic_constant_index", - "0")); - OP_REQUIRES_OK(ctx, - b->SetInstructionFrontendAttribute( - scalar_const, "dynamic_constant_expr", - expr_textproto)); 
- dimension_constants.push_back(xla::Reshape(scalar_const, {1})); - } else { - int32_t dim_val = static_cast(shape.dim_size(i)); - xla::XlaOp scalar_const = xla::ConstantR0(b, dim_val); - dimension_constants.push_back(xla::Reshape(scalar_const, {1})); - } - } - - xla::XlaOp combined_shape_constant = xla::ConcatInDim(b, - dimension_constants, 0); - ctx->SetOutput(0, combined_shape_constant); - } else { - Tensor tensor(proto_.dtype()); - OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), - errors::InvalidArgument("Cannot parse tensor from proto: ", - proto_.DebugString())); - ctx->SetConstantOutput(0, tensor); + LOG(INFO) << "ConstOp tensor path tensor_shape=" + << tensor.shape().DebugString() << " inferred_rank=" + << inferred_shape_proto.dim_size(); + } + XlaExpression output = XlaExpression::Constant(tensor); + if (has_dynamic && tensor.dims() == 1 && + tensor.dim_size(0) == inferred_shape_proto.dim_size()) { + LOG(INFO) << "ConstOp attaching shape contents to folded const with " + << tensor.dim_size(0) << " entries and dynamic_exprs=" + << CountDynamicShapeContents(inferred_shape_proto); + output.set_contents( + BuildShapeContentsFromTensorShapeProto(inferred_shape_proto)); } + ctx->SetOutputExpression(0, output); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 6c91556862d9e2..05c10e6957826e 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -32,10 +32,182 @@ limitations under the License. 
#include "xla/shape.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { +namespace { + +bool IsSymbolicContentType(DataType type) { + type = BaseType(type); + return type == DT_INT32 || type == DT_INT64; +} + +bool TryGetIntContentsFromConstant(const Tensor& tensor, + std::vector* contents) { + contents->clear(); + if (tensor.dims() > 1) { + return false; + } + if (tensor.dtype() == DT_INT32) { + if (tensor.dims() == 0) { + contents->push_back(xla::DExpr::Const(tensor.scalar()())); + return true; + } + auto flat = tensor.flat(); + contents->reserve(flat.size()); + for (int i = 0; i < flat.size(); ++i) { + contents->push_back(xla::DExpr::Const(flat(i))); + } + return true; + } + if (tensor.dtype() == DT_INT64) { + if (tensor.dims() == 0) { + contents->push_back(xla::DExpr::Const(tensor.scalar()())); + return true; + } + auto flat = tensor.flat(); + contents->reserve(flat.size()); + for (int i = 0; i < flat.size(); ++i) { + contents->push_back(xla::DExpr::Const(flat(i))); + } + return true; + } + return false; +} + +bool TryGetIntContentsFromLiteral(const xla::LiteralSlice& literal, + std::vector* contents) { + contents->clear(); + if (literal.shape().dimensions_size() > 1) { + return false; + } + if (literal.shape().element_type() == xla::S32) { + if (literal.shape().dimensions_size() == 0) { + contents->push_back(xla::DExpr::Const(literal.Get({}))); + return true; + } + const int64_t size = literal.shape().dimensions(0); + contents->reserve(size); + for (int64_t i = 0; i < size; ++i) { + contents->push_back(xla::DExpr::Const(literal.Get({i}))); + } + return true; + } + if (literal.shape().element_type() == xla::S64) { + if (literal.shape().dimensions_size() == 0) { + contents->push_back(xla::DExpr::Const(literal.Get({}))); + return true; + } + const int64_t size = literal.shape().dimensions(0); + 
contents->reserve(size); + for (int64_t i = 0; i < size; ++i) { + contents->push_back(xla::DExpr::Const(literal.Get({i}))); + } + return true; + } + return false; +} + +bool TryGetInputContents(XlaOpKernelContext* ctx, const XlaExpression& expr, + const TensorShape& shape, + std::vector* contents) { + if (shape.dims() > 1) { + return false; + } + if (!expr.contents().empty()) { + if (absl::c_any_of(expr.contents(), + [](const xla::DExpr& e) { return !e; })) { + return false; + } + contents->assign(expr.contents().begin(), expr.contents().end()); + return true; + } + auto constant = expr.constant_value(); + if (!constant.has_value()) { + if (!expr.handle().valid() || expr.handle().IsUninitialized()) { + return false; + } + auto literal_or = + ctx->value_inference().AnalyzeConstant(expr.handle(), + xla::ValueInferenceMode::kValue); + if (!literal_or.ok() || !literal_or->AllValid()) { + return false; + } + auto literal = literal_or->GetValue(); + if (!literal.has_value()) { + return false; + } + return TryGetIntContentsFromLiteral(*literal, contents); + } + return TryGetIntContentsFromConstant(*constant, contents); +} + +xla::DExpr BroadcastedContentAt(absl::Span contents, + const TensorShape& shape, + int64_t output_index) { + if (shape.dims() == 0) { + return contents.empty() ? 
xla::DExpr() : contents[0]; + } + if (contents.empty()) { + return xla::DExpr(); + } + if (shape.dim_size(0) == 1) { + return contents[0]; + } + if (output_index >= contents.size()) { + return xla::DExpr(); + } + return contents[output_index]; +} + +bool TryBuildSymbolicBinaryContents(XlaOpKernelContext* ctx, + XlaBinaryOp* op, + const TensorShape& lhs_shape, + const TensorShape& rhs_shape, + const BCast& bcast, + std::vector* contents) { + contents->clear(); + if (!IsSymbolicContentType(ctx->input_type(0)) || + !IsSymbolicContentType(ctx->expected_output_dtype(0))) { + return false; + } + const auto& output_shape = bcast.output_shape(); + if (output_shape.size() > 1) { + return false; + } + + std::vector lhs_contents; + std::vector rhs_contents; + if (!TryGetInputContents(ctx, ctx->InputExpression(0), lhs_shape, + &lhs_contents) || + !TryGetInputContents(ctx, ctx->InputExpression(1), rhs_shape, + &rhs_contents)) { + return false; + } + + int64_t output_elements = output_shape.empty() ? 1 : output_shape[0]; + contents->reserve(output_elements); + for (int64_t i = 0; i < output_elements; ++i) { + xla::DExpr lhs_expr = BroadcastedContentAt(lhs_contents, lhs_shape, i); + xla::DExpr rhs_expr = BroadcastedContentAt(rhs_contents, rhs_shape, i); + if (!lhs_expr || !rhs_expr) { + contents->clear(); + return false; + } + xla::DExpr out_expr = op->SymbolicComputation(lhs_expr, rhs_expr); + if (!out_expr) { + contents->clear(); + return false; + } + contents->push_back(out_expr.simplify()); + } + return true; +} + +} // namespace + void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { TensorShape lhs_shape = ctx->InputShape(0); TensorShape rhs_shape = ctx->InputShape(1); @@ -192,6 +364,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // The TensorFlow helper computed the post-broadcast shape in // output_shape: we rely on subclassed Computations to implement the // same broadcast semantics. 
+ std::vector output_contents; + if (TryBuildSymbolicBinaryContents(ctx, this, lhs_shape, rhs_shape, bcast, + &output_contents)) { + auto output_expr = + XlaExpression::XlaOp(output, ctx->expected_output_dtype(0)); + output_expr.set_contents(std::move(output_contents)); + ctx->SetOutputExpression(0, output_expr); + return; + } ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index d22e6eb74039b4..df92dac718db13 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -67,6 +67,14 @@ class XlaBinaryOp : public XlaOpKernel { const absl::Span& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; + // Returns a symbolic expression for one output element when content metadata + // should be propagated through this op. Returns an empty DExpr when the operation + // should not propagate symbolic contents. + virtual xla::DExpr SymbolicComputation(const xla::DExpr& lhs, + const xla::DExpr& rhs) { + return xla::DExpr(); + } + void Compile(XlaOpKernelContext* ctx) override; // Helper function that performs the broadcasting described by diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 6e5a1430538365..97e4c6e85d4dcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -53,11 +53,20 @@ class FillOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector( "dims", &dims, xla::ValueInferenceMode::kUpperBound)); + std::vector dim_exprs; + dim_exprs.reserve(dims.size()); + const auto& contents = ctx->InputExpression("dims").contents(); + for (int64_t i = 0; i < dims.size(); ++i) { + dim_exprs.push_back(i < contents.size() && contents[i] && + contents[i]->is_dynamic() + ? 
contents[i] + : xla::DExpr::Const(dims[i])); + } std::vector dynamic_dims; OP_REQUIRES_OK( ctx, ctx->ResolveInputDynamismIntoPredVector("dims", &dynamic_dims)); - auto output = xla::Broadcast(ctx->Input("value"), dims); + auto output = xla::Broadcast(ctx->Input("value"), dims, dim_exprs); for (int64_t i = 0; i < dims.size(); ++i) { // If a dimension is dynamic, call set-dimension-size on the output. if (dynamic_dims[i]) { diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 3e765b853e110d..7b04d501d305de 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -32,6 +32,10 @@ class IdentityOp : public XlaOpKernel { for (int i = 0; i < ctx->num_inputs(); ++i) { if (IsTensorListInput(ctx, i)) { ctx->SetTensorListOutput(i, ctx->Input(i)); + } else if (ctx->InputExpression(i).kind() != + XlaExpression::Kind::kResource && + ctx->input_type(i) != DT_VARIANT) { + ctx->SetOutputExpression(i, ctx->InputExpression(i)); } else { DCHECK(ctx->input_type(i) != DT_VARIANT); // Forwards using the underlying op_kernel_context so both tensor and diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 399d26b6f55de0..9704447b037b35 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" @@ -28,6 +30,33 @@ limitations under the License. 
namespace tensorflow { namespace { +bool TryBuildPackedContents(XlaOpKernelContext* ctx, int num, int axis, + std::vector* contents) { + contents->clear(); + if (axis != 0) { + return false; + } + for (int i = 0; i < num; ++i) { + if (ctx->InputShape(i).dims() != 0) { + contents->clear(); + return false; + } + const auto& input_contents = ctx->InputExpression(i).contents(); + if (!input_contents.empty()) { + if (input_contents.size() != 1) { + contents->clear(); + return false; + } + contents->push_back(input_contents[0]); + continue; + } + contents->push_back(xla::DExpr::Const(xla::kUnknownContentSentinel)); + } + return absl::c_any_of(*contents, [](const xla::DExpr& expr) { + return expr && expr->is_dynamic(); + }); +} + class PackOp : public XlaOpKernel { public: explicit PackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -67,13 +96,33 @@ class PackOp : public XlaOpKernel { std::vector exprs = child_shape.get_filled_expressions(); child_shape.InsertDim(axis, 1); exprs.insert(exprs.begin() + axis, xla::DExpr::Const(1)); + std::vector output_contents; + const bool has_output_contents = + SymbolicContentEnabled() && + TryBuildPackedContents(ctx, num, axis, &output_contents); for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. 
reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes(), exprs); + if (has_output_contents) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents( + reshaped_inputs[i], + {output_contents[static_cast(i)]})); + } } - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); + auto output = xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis); + if (has_output_contents) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(output, output_contents)); + auto output_expr = + XlaExpression::XlaOp(output, ctx->expected_output_dtype(0)); + output_expr.set_contents(std::move(output_contents)); + ctx->SetOutputExpression(0, output_expr); + return; + } + ctx->SetOutput(0, output); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 108bf3848aae93..8f8a34899fc46a 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -17,8 +17,9 @@ limitations under the License. #include #include - +#include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" @@ -26,6 +27,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -33,17 +35,83 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { +template +xla::DExpr GetScalarExpr(const XlaExpression& expression, + const xla::LiteralSlice& literal) { + const auto& contents = expression.contents(); + if (!contents.empty() && contents[0]) { + return contents[0]; + } + return xla::DExpr::Const(literal.Get({})); +} + +bool HasStaticScalarContent(const XlaExpression& expression) { + const auto& contents = expression.contents(); + return contents.empty() || + (contents[0] && contents[0]->is_constant()); +} + +bool HasDynamicContent(const XlaExpression& expression) { + return absl::c_any_of(expression.contents(), [](const xla::DExpr& expr) { + return expr && expr->is_dynamic(); + }); +} + +template +std::vector BuildRangeContents(const XlaExpression& start_expr, + const XlaExpression& delta_expr, + const xla::LiteralSlice& start, + const xla::LiteralSlice& delta, + int64_t size) { + std::vector contents; + contents.reserve(size); + xla::DExpr start_symbol = GetScalarExpr(start_expr, start); + xla::DExpr delta_symbol = GetScalarExpr(delta_expr, delta); + for (int64_t i = 0; i < size; ++i) { + xla::DExpr offset = xla::DExpr::Const(static_cast(i)); + contents.push_back((start_symbol + (delta_symbol * offset).simplify()).simplify()); + } + return contents; +} + +template +xla::DExpr BuildRangeSizeExpr(const XlaExpression& start_expr, + const XlaExpression& limit_expr, + const XlaExpression& delta_expr, + const xla::LiteralSlice& start, + const xla::LiteralSlice& limit, + const xla::LiteralSlice& delta, + int64_t fallback_size) { + xla::DExpr start_symbol = GetScalarExpr(start_expr, start); + xla::DExpr limit_symbol = GetScalarExpr(limit_expr, limit); + xla::DExpr delta_symbol = GetScalarExpr(delta_expr, delta); + + if (delta.Get({}) > 0) { + xla::DExpr diff 
= (limit_symbol - start_symbol).simplify(); + xla::DExpr adjusted = (diff - 1).simplify(); + xla::DExpr quotient = (adjusted / delta_symbol).simplify(); + return (quotient + 1).simplify(); + } + xla::DExpr step_symbol = (xla::DExpr::Const(0) - delta_symbol).simplify(); + xla::DExpr diff = (start_symbol - limit_symbol).simplify(); + xla::DExpr adjusted = (diff - 1).simplify(); + xla::DExpr quotient = (adjusted / step_symbol).simplify(); + return (quotient + 1).simplify(); +} + // The type-specific part of the implementation of Range. template absl::StatusOr CreateRangeTensor( const xla::LiteralSlice& start_literal, const xla::LiteralSlice& limit_literal, - const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) { + const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder, + xla::DExpr size_expr = xla::DExpr()) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -70,10 +138,17 @@ absl::StatusOr CreateRangeTensor( : (std::abs(limit - start) - 1) / std::abs(delta) + 1) : std::ceil(std::abs((limit - start) / delta))); - return xla::ConstantR0(builder, start) + - xla::ConstantR0(builder, delta) * - xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType(), - size); + xla::XlaOp iota = + (std::is_integral::value && size_expr) + ? 
xla::Iota(builder, + xla::ShapeUtil::MakeShape( + xla::primitive_util::NativeToPrimitiveType(), + {size}, {size_expr}), + /*iota_dimension=*/0) + : xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType(), + size); + + return xla::ConstantR0(builder, start) + xla::ConstantR0(builder, delta) * iota; } class RangeOp : public XlaOpKernel { @@ -103,13 +178,48 @@ class RangeOp : public XlaOpKernel { DataType type = input_type(0); absl::StatusOr output; switch (type) { - case DT_INT32: - output = CreateRangeTensor(start, limit, delta, ctx->builder()); + case DT_INT32: { + int32 start_value = start.Get({}); + int32 limit_value = limit.Get({}); + int32 delta_value = delta.Get({}); + int64_t size = static_cast( + limit_value == start_value + ? 0 + : (std::abs(limit_value - start_value) - 1) / + std::abs(delta_value) + + 1); + xla::DExpr size_expr = + HasStaticScalarContent(ctx->InputExpression(2)) + ? BuildRangeSizeExpr(ctx->InputExpression(0), + ctx->InputExpression(1), + ctx->InputExpression(2), start, + limit, delta, size) + : xla::DExpr::Const(size); + output = CreateRangeTensor(start, limit, delta, ctx->builder(), + size_expr); break; - case DT_INT64: - output = - CreateRangeTensor(start, limit, delta, ctx->builder()); + } + case DT_INT64: { + int64_t start_value = start.Get({}); + int64_t limit_value = limit.Get({}); + int64_t delta_value = delta.Get({}); + int64_t size = + limit_value == start_value + ? 0 + : (std::abs(limit_value - start_value) - 1) / + std::abs(delta_value) + + 1; + xla::DExpr size_expr = + HasStaticScalarContent(ctx->InputExpression(2)) + ? 
BuildRangeSizeExpr(ctx->InputExpression(0), + ctx->InputExpression(1), + ctx->InputExpression(2), start, + limit, delta, size) + : xla::DExpr::Const(size); + output = CreateRangeTensor(start, limit, delta, ctx->builder(), + size_expr); break; + } case DT_FLOAT: output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; @@ -145,7 +255,53 @@ class RangeOp : public XlaOpKernel { } } - ctx->SetOutput(0, output.value()); + const XlaExpression& start_expr = ctx->InputExpression(0); + const XlaExpression& delta_expr = ctx->InputExpression(2); + const bool symbolic_enabled = SymbolicContentEnabled(); + const bool has_dynamic_content = + HasDynamicContent(start_expr) || HasDynamicContent(delta_expr); + + if (type == DT_INT32) { + int32 start_value = start.Get({}); + int32 limit_value = limit.Get({}); + int32 delta_value = delta.Get({}); + if (symbolic_enabled && has_dynamic_content) { + int64_t size = static_cast( + limit_value == start_value + ? 0 + : (std::abs(limit_value - start_value) - 1) / + std::abs(delta_value) + + 1); + auto output_expr = + XlaExpression::XlaOp(output.value(), ctx->expected_output_dtype(0)); + output_expr.set_contents(BuildRangeContents( + start_expr, delta_expr, start, delta, size)); + ctx->SetOutputExpression(0, output_expr); + } else { + ctx->SetOutput(0, output.value()); + } + } else if (type == DT_INT64) { + int64_t start_value = start.Get({}); + int64_t limit_value = limit.Get({}); + int64_t delta_value = delta.Get({}); + if (symbolic_enabled && has_dynamic_content) { + int64_t size = static_cast( + limit_value == start_value + ? 
0 + : (std::abs(limit_value - start_value) - 1) / + std::abs(delta_value) + + 1); + auto output_expr = + XlaExpression::XlaOp(output.value(), ctx->expected_output_dtype(0)); + output_expr.set_contents(BuildRangeContents( + start_expr, delta_expr, start, delta, size)); + ctx->SetOutputExpression(0, output_expr); + } else { + ctx->SetOutput(0, output.value()); + } + } else { + ctx->SetOutput(0, output.value()); + } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 510da9023f6348..3d58c0463f7366 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -46,6 +47,18 @@ limitations under the License. namespace tensorflow { namespace { +std::vector BuildShapeContents(const TensorShape& input_shape) { + std::vector contents; + contents.reserve(input_shape.dims()); + for (int64_t i = 0; i < input_shape.dims(); ++i) { + xla::DExpr expr = input_shape.get_filled_expression(i); + contents.push_back(expr && expr->is_dynamic() + ? 
expr + : xla::DExpr::Const(xla::kUnknownContentSentinel)); + } + return contents; +} + class ShapeOp : public XlaOpKernel { public: explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -58,18 +71,46 @@ class ShapeOp : public XlaOpKernel { const int rank = input_shape.dims(); if (rank != 0) { for (int64_t i = 0; i < rank; ++i) { - operands.push_back(xla::Broadcast( - xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i), - ctx->output_xla_type(0)), - {1})); + xla::DExpr expr = input_shape.get_filled_expression(i); + std::vector content = { + expr && expr->is_dynamic() + ? expr + : xla::DExpr::Const(xla::kUnknownContentSentinel)}; + xla::XlaOp dim_size = xla::GetDimensionSize(ctx->Input(0), i); + if (SymbolicContentEnabled()) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(dim_size, content)); + } + xla::XlaOp converted = + xla::ConvertElementType(dim_size, ctx->output_xla_type(0)); + if (SymbolicContentEnabled()) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(converted, content)); + } + xla::XlaOp broadcast = xla::Broadcast(converted, {1}); + if (SymbolicContentEnabled()) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(broadcast, content)); + } + operands.push_back(broadcast); } - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), operands, 0)); + xla::XlaOp concat = xla::ConcatInDim(ctx->builder(), operands, 0); + XlaExpression output = + XlaExpression::XlaOp(concat, ctx->expected_output_dtype(0)); + if (SymbolicContentEnabled()) { + output.set_contents(BuildShapeContents(input_shape)); + } + ctx->SetOutputExpression(0, output); } else { // Rank 0 won't have dynamic size dimension, use constant output. 
Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); - ctx->SetConstantOutput(0, shape_constant); + XlaExpression output = XlaExpression::Constant(shape_constant); + if (SymbolicContentEnabled()) { + output.set_contents(BuildShapeContents(input_shape)); + } + ctx->SetOutputExpression(0, output); } } @@ -196,19 +237,47 @@ class ShapeNOp : public XlaOpKernel { // Each dimension can be dynamic, so use GetDimensionSize to get the // runtime dimension. for (int64_t dim = 0; dim < rank; ++dim) { - operands.push_back(xla::Broadcast( - xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(i), dim), - ctx->output_xla_type(i)), - {1})); + xla::DExpr expr = input_shape.get_filled_expression(dim); + std::vector content = { + expr && expr->is_dynamic() + ? expr + : xla::DExpr::Const(xla::kUnknownContentSentinel)}; + xla::XlaOp dim_size = xla::GetDimensionSize(ctx->Input(i), dim); + if (SymbolicContentEnabled()) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(dim_size, content)); + } + xla::XlaOp converted = + xla::ConvertElementType(dim_size, ctx->output_xla_type(i)); + if (SymbolicContentEnabled()) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(converted, content)); + } + xla::XlaOp broadcast = xla::Broadcast(converted, {1}); + if (SymbolicContentEnabled()) { + OP_REQUIRES_OK( + ctx, ctx->builder()->SetInstructionContents(broadcast, content)); + } + operands.push_back(broadcast); } - ctx->SetOutput(i, xla::ConcatInDim(ctx->builder(), operands, 0)); + XlaExpression output = + XlaExpression::XlaOp(xla::ConcatInDim(ctx->builder(), operands, 0), + ctx->expected_output_dtype(i)); + if (SymbolicContentEnabled()) { + output.set_contents(BuildShapeContents(input_shape)); + } + ctx->SetOutputExpression(i, output); } else { // Rank 0 won't have dynamic size dimension, use constant output. 
Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); - ctx->SetConstantOutput(i, shape_constant); + XlaExpression output = XlaExpression::Constant(shape_constant); + if (SymbolicContentEnabled()) { + output.set_contents(BuildShapeContents(input_shape)); + } + ctx->SetOutputExpression(i, output); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 57f48987e8a6b2..0dee93660c6ce9 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/lib/constants.h" @@ -37,6 +39,36 @@ limitations under the License. namespace tensorflow { namespace { +bool TryBuildSlicedContents(const XlaExpression& input_expr, + const TensorShape& input_shape, + absl::Span begin, + absl::Span size, + std::vector* output_contents) { + output_contents->clear(); + const auto& input_contents = input_expr.contents(); + if (input_contents.empty() || input_shape.dims() != 1 || begin.size() != 1 || + size.size() != 1) { + return false; + } + const int64_t start = begin[0]; + const int64_t count = + size[0] == -1 ? input_shape.dim_size(0) - start : size[0]; + for (int64_t i = 0; i < count; ++i) { + const int64_t index = start + i; + if (index < 0 || index >= input_contents.size()) { + output_contents->clear(); + return false; + } + const xla::DExpr& expr = input_contents[index]; + output_contents->push_back(expr ? 
expr
+                                     : xla::DExpr::Const(
+                                           xla::kUnknownContentSentinel));
+  }
+  return absl::c_any_of(*output_contents, [](const xla::DExpr& expr) {
+    return expr && expr->is_dynamic();
+  });
+}
+
 class SliceOp : public XlaOpKernel {
  public:
  explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
@@ -112,10 +144,19 @@ class SliceOp : public XlaOpKernel {
     for (int i = 0; i < begin.size(); ++i) {
       limits.push_back(begin[i] + wrapped_size[i]);
       exprs.push_back((begin_exprs[i] + wrapped_size_exprs[i]).simplify());
     }
     std::vector strides(begin.size(), 1);
     auto slice = xla::Slice(ctx->Input(0), begin, limits, begin_exprs, exprs,
                             strides);
+    std::vector output_contents;
+    const bool has_output_contents =
+        SymbolicContentEnabled() &&
+        TryBuildSlicedContents(ctx->InputExpression(0), input_shape, begin,
+                               size, &output_contents);
+    if (has_output_contents) {
+      OP_REQUIRES_OK(
+          ctx, ctx->builder()->SetInstructionContents(slice, output_contents));
+    }
     // Check for slice on dynamic dimensions.
std::vector size_is_dynamic; OP_REQUIRES_OK( @@ -132,9 +175,21 @@ class SliceOp : public XlaOpKernel { {}); slice = xla::SetDimensionSize(slice, dynamic_size, i); + if (has_output_contents) { + OP_REQUIRES_OK( + ctx, + ctx->builder()->SetInstructionContents(slice, output_contents)); + } } } } + if (has_output_contents) { + auto output_expr = + XlaExpression::XlaOp(slice, ctx->expected_output_dtype(0)); + output_expr.set_contents(std::move(output_contents)); + ctx->SetOutputExpression(0, output_expr); + return; + } ctx->SetOutput(0, slice); } else { // When a size is -1, we take rest of the dimension according to diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index bbe7f63a80a823..ea18811d4483b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -48,6 +49,38 @@ namespace tensorflow { namespace { using errors::InvalidArgument; +bool TryBuildSlicedContents(const XlaExpression& input_expr, + const TensorShape& input_shape, + const absl::InlinedVector& begin, + const absl::InlinedVector& strides, + const TensorShape& final_shape, + std::vector* output_contents) { + output_contents->clear(); + const auto& input_contents = input_expr.contents(); + if (input_contents.empty() || input_shape.dims() != 1 || begin.size() != 1 || + strides.size() != 1) { + return false; + } + + const int64_t output_elements = final_shape.num_elements(); + const int64_t start = begin[0]; + const int64_t stride = strides[0]; + for (int64_t i = 0; i < 
output_elements; ++i) { + const int64_t index = start + i * stride; + if (index < 0 || index >= input_contents.size()) { + output_contents->clear(); + return false; + } + const xla::DExpr& expr = input_contents[index]; + output_contents->push_back(expr ? expr + : xla::DExpr::Const( + xla::kUnknownContentSentinel)); + } + return absl::c_any_of(*output_contents, [](const xla::DExpr& expr) { + return expr && expr->is_dynamic(); + }); +} + class StridedSliceOp : public XlaOpKernel { public: explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -83,7 +116,7 @@ class StridedSliceOp : public XlaOpKernel { i, input_shape.dim_size(shape_spec.output_to_processing_mapping[i])); partial_final_shape.set_expression( - i, input_shape.get_filled_expression( + i, input_shape.get_expression( shape_spec.output_to_processing_mapping[i])); } } @@ -101,7 +134,7 @@ class StridedSliceOp : public XlaOpKernel { // dimension is unknown, we use input shape as bound. partial_processing_shape.set_dim(i, input_shape.dim_size(i)); partial_processing_shape.set_expression(i, - input_shape.get_filled_expression(i)); + input_shape.get_expression(i)); } } TensorShape processing_shape; @@ -224,11 +257,11 @@ class StridedSliceOp : public XlaOpKernel { slice = xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes(), - processing_shape.get_filled_expressions()); + processing_shape.get_expressions()); // new_axis_mask_, ellipsis_mask_ and shrink_axis_mask_ may add or remove // size 1 dims of a shape. 
slice = xla::Reshape(slice, final_shape.dim_sizes(),
-                         final_shape.get_filled_expressions());
+                         final_shape.get_expressions());
     for (int64_t i = 0; i < final_shape.dims(); ++i) {
       int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i];
       // If processing_shape_dim is -1, it means the output dimension was newly
@@ -340,6 +373,15 @@
       }
       slice = xla::Slice(slice, slice_begin, slice_end, slice_begin_expr,
                          slice_end_expr, slice_strides);
+      std::vector output_contents;
+      const bool has_output_contents =
+          SymbolicContentEnabled() &&
+          TryBuildSlicedContents(ctx->InputExpression(0), input_shape, begin,
+                                 strides, final_shape, &output_contents);
+      if (has_output_contents) {
+        OP_REQUIRES_OK(
+            ctx, ctx->builder()->SetInstructionContents(slice, output_contents));
+      }
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.value();
@@ -353,8 +395,19 @@
         ends_are_dynamic, [](bool dynamic) { return !dynamic; });
     // Static output shape, return a static slice.
slice = xla::Reshape(slice, final_shape.dim_sizes(),
-                         final_shape.get_filled_expressions());
+                         final_shape.get_expressions());
+    if (has_output_contents) {
+      OP_REQUIRES_OK(
+          ctx, ctx->builder()->SetInstructionContents(slice, output_contents));
+    }
     if (xla_shape.is_static() && ends_are_static) {
+      if (has_output_contents) {
+        auto output_expr =
+            XlaExpression::XlaOp(slice, ctx->expected_output_dtype(0));
+        output_expr.set_contents(std::move(output_contents));
+        ctx->SetOutputExpression(0, output_expr);
+        return;
+      }
       ctx->SetOutput(0, slice);
       return;
     }
@@ -413,8 +467,20 @@
           xla::Sub(operand_size, xla::ConstantR0(
                                      ctx->builder(), begin[input_index])),
           i);
+      if (has_output_contents) {
+        OP_REQUIRES_OK(
+            ctx, ctx->builder()->SetInstructionContents(slice,
+                                                        output_contents));
+      }
       }
     }
+    if (has_output_contents) {
+      auto output_expr =
+          XlaExpression::XlaOp(slice, ctx->expected_output_dtype(0));
+      output_expr.set_contents(std::move(output_contents));
+      ctx->SetOutputExpression(0, output_expr);
+      return;
+    }
     ctx->SetOutput(0, slice);
     return;
   } else {
@@ -547,7 +613,7 @@
         zero = xla::Broadcast(zero, input_sizes_padded, input_exprs_padded);
     grad = xla::Reshape(grad, processing_shape.dim_sizes(),
-                        processing_shape.get_filled_expressions());
+                        processing_shape.get_expressions());
     grad = xla::DynamicUpdateSlice(zero, grad, begins);
     if (need_padding) {
       // We padded the input shape to avoid OOB when DUS. Now slice out the
@@ -617,7 +683,7 @@
       // Undo any new/shrink axes.
grad = xla::Reshape(grad, processing_shape.dim_sizes(), - processing_shape.get_filled_expressions()); + processing_shape.get_expressions()); // Pad the input gradients. absl::InlinedVector dimensions_to_reverse; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index a7a1a438f95b9e..f8d58b71e660fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/symbolic_content_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" @@ -50,6 +51,18 @@ absl::Status ValidateAssignUpdateVariableOpShapes(XlaOpKernelContext* ctx) { return absl::OkStatus(); } +std::vector BuildVariableShapeContents(const TensorShape& shape) { + std::vector contents; + contents.reserve(shape.dims()); + for (int i = 0; i < shape.dims(); ++i) { + xla::DExpr expr = shape.get_filled_expression(i); + contents.push_back(expr && expr->is_dynamic() + ? 
expr + : xla::DExpr::Const(xla::kUnknownContentSentinel)); + } + return contents; +} + class VarIsInitializedOp : public XlaOpKernel { public: explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -75,7 +88,11 @@ class VariableShapeOp : public XlaOpKernel { ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); - ctx->SetConstantOutput(0, shape_constant); + auto output = XlaExpression::Constant(shape_constant); + if (SymbolicContentEnabled()) { + output.set_contents(BuildVariableShapeContents(shape)); + } + ctx->SetOutputExpression(0, output); } private: diff --git a/tensorflow/compiler/tf2xla/symbolic_content_util.h b/tensorflow/compiler/tf2xla/symbolic_content_util.h new file mode 100644 index 00000000000000..3dd471a4e9ab02 --- /dev/null +++ b/tensorflow/compiler/tf2xla/symbolic_content_util.h @@ -0,0 +1,29 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SYMBOLIC_CONTENT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SYMBOLIC_CONTENT_UTIL_H_ + +#include "tensorflow/compiler/jit/flags.h" + +namespace tensorflow { + +inline bool SymbolicContentEnabled() { + return GetMarkForCompilationPassFlags()->tf_xla_enable_symbolic_content; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SYMBOLIC_CONTENT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_argument.cc b/tensorflow/compiler/tf2xla/xla_argument.cc index 8b91dd3870b7d5..71ea93bf4fc628 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.cc +++ b/tensorflow/compiler/tf2xla/xla_argument.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "absl/algorithm/container.h" #include "llvm/ADT/STLExtras.h" namespace tensorflow { @@ -46,6 +47,14 @@ bool XlaArgument::operator==(const XlaArgument& other) const { if (constant_value.shape() != other.constant_value.shape()) { return false; } + if (!absl::c_equal( + constant_value_expressions, other.constant_value_expressions, + [](const xla::ExpressionProto& lhs, + const xla::ExpressionProto& rhs) { + return lhs.SerializeAsString() == rhs.SerializeAsString(); + })) { + return false; + } if (is_same_data_across_replicas != other.is_same_data_across_replicas) { return false; } diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index bbdc7df77e1b75..6dc37ea1c0f963 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ +#include + #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" @@ -23,6 +25,7 @@ limitations under the License. #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" namespace tensorflow { @@ -81,6 +84,11 @@ struct XlaArgument { // reinterpret as coming from a dynamic expression instead of the literal. int64_t dynamic_constant_index = -1; std::optional dynamic_constant_expr; + + // Symbolic expressions for each element of a compile-time constant. + // This is only used for shape-like integer tensors crossing cluster + // boundaries. + std::vector constant_value_expressions; // The upper bounds of the value. std::optional value_bound; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3e33fdecccc629..e71449f354b71b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/compiler/tf2xla/symbolic_content_util.h"
 #include "tensorflow/compiler/jit/xla_compile_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
@@ -1130,6 +1131,25 @@ absl::Status XlaCompiler::BuildArguments(
           arg_expression.set_dynamic_constant_expr(*arg.dynamic_constant_expr);
         }
       }
+      if (!arg.constant_value_expressions.empty()) {
+        LOG(INFO) << "BuildArguments attaching "
+                  << arg.constant_value_expressions.size()
+                  << " constant_value_expressions to constant arg " << i
+                  << " (" << arg.name << ")";
+        // Preserve symbolic per-element metadata for shape-like constants so
+        // later tf2xla consumers can recover dynamic contents from them.
+        std::vector contents;
+        contents.reserve(arg.constant_value_expressions.size());
+        for (const xla::ExpressionProto& expr :
+             arg.constant_value_expressions) {
+          xla::DExpr parsed = xla::DExprFromProto(expr);
+          contents.push_back(parsed && parsed->is_dynamic()
+                                 ? std::move(parsed)
+                                 : xla::DExpr::Const(
+                                       xla::kUnknownContentSentinel));
+        }
+        arg_expression.set_contents(std::move(contents));
+      }
       break;
     case XlaCompiler::Argument::kInvalid:
       return errors::Internal(
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index a6485b593e2591..5e12d38e191a09 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -18,8 +18,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/compiler/tf2xla/symbolic_content_util.h"
 #include "xla/hlo/builder/value_inference.h"
 #include "tsl/platform/protobuf.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
@@ -76,6 +77,50 @@ XlaExpression XlaExpression::Resource(XlaResource* resource) {
   return e;
 }
 
+void XlaExpression::set_contents(std::vector contents) {
+  switch (kind_) {
+    case Kind::kXlaOp:
+    case Kind::kTensorList:
+      if (handle_.valid() && !handle_.IsUninitialized()) {
+        if (SymbolicContentEnabled()) {
+          auto status = handle_.builder()->SetInstructionContents(
+              handle_, std::move(contents));
+          if (!status.ok()) {
+            LOG(INFO) << "Failed to set XlaOp contents: " << status;
+          }
+          return;
+        }
+        break;
+      }
+      break;
+    case Kind::kInvalid:
+    case Kind::kConstant:
+    case Kind::kResource:
+      break;
+  }
+  local_contents_ = std::move(contents);
+}
+
+absl::Span XlaExpression::contents() const {
+  switch (kind_) {
+    case Kind::kXlaOp:
+    case Kind::kTensorList:
+      if (SymbolicContentEnabled() && handle_.valid() &&
+          !handle_.IsUninitialized()) {
+        auto contents_or = handle_.builder()->GetInstructionContents(handle_);
+        if (contents_or.ok()) {
+          return **contents_or;
+        }
+      }
+      break;
+    case Kind::kInvalid:
+    case Kind::kConstant:
+    case Kind::kResource:
+      break;
+  }
+  return local_contents_;
+}
+
 string XlaExpression::HumanString() const {
   switch (kind_) {
     case Kind::kInvalid:
@@ -98,28 +143,25 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
       xla::BorrowingLiteral literal;
       TF_RETURN_IF_ERROR(
           HostTensorToBorrowingLiteral(*constant_value_, &literal));
-      if (!dynamic_constant_index_.has_value() ||
-          !dynamic_constant_expr_.has_value()) {
-        return xla::ConstantLiteral(builder, literal);
+      xla::XlaOp op = 
xla::ConstantLiteral(builder, literal); + if (dynamic_constant_index_.has_value() && + dynamic_constant_expr_.has_value()) { + xla::FrontendAttributes attributes = builder->frontend_attributes(); + (*attributes.mutable_map())["dynamic_constant_index"] = + std::to_string(*dynamic_constant_index_); + xla::ExpressionProto expr_proto; + dynamic_constant_expr_->to_proto(&expr_proto); + std::string expr_textproto = + tsl::LegacyUnredactedShortDebugString(expr_proto); + (*attributes.mutable_map())["dynamic_constant_expr"] = expr_textproto; + xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes( + builder, attributes); + op = xla::ConstantLiteral(builder, literal); } - - xla::FrontendAttributes attributes = builder->frontend_attributes(); - (*attributes.mutable_map())["dynamic_constant_index"] = - std::to_string(*dynamic_constant_index_); - xla::ExpressionProto expr_proto; - dynamic_constant_expr_->to_proto(&expr_proto); - std::string expr_textproto = - tsl::LegacyUnredactedShortDebugString(expr_proto); - VLOG(1) << "AsXlaOp: expr_textproto is " << expr_textproto - << " ShortDebugString is " << expr_proto.ShortDebugString(); - (*attributes.mutable_map())["dynamic_constant_expr"] = expr_textproto; - VLOG(1) << "Marking HLO constant with dynamic_constant_index=" - << *dynamic_constant_index_ - << " dynamic_constant_expr=" - << expr_proto.ShortDebugString(); - xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes( - builder, attributes); - return xla::ConstantLiteral(builder, literal); + if (!local_contents_.empty()) { + TF_RETURN_IF_ERROR(builder->SetInstructionContents(op, local_contents_)); + } + return op; } case Kind::kTensorList: TF_FALLTHROUGH_INTENDED; diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index a9c20705e6bffc..0230cc9fbd4e7a 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -16,11 +16,15 @@ limitations under the 
License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#include + #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/client.h" #include "xla/hlo/builder/value_inference.h" #include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/statusor.h" @@ -125,6 +129,13 @@ class XlaExpression { dynamic_constant_expr_ = std::move(expr); } + // Set symbolic content metadata for expressions whose values should retain + // links to symbolic dimensions across shape-tensor flows. + void set_contents(std::vector contents); + + // Return symbolic content metadata. + absl::Span contents() const; + XlaResource* resource() const { return resource_; } // Returns a human-readable summary of the expression. @@ -182,6 +193,10 @@ class XlaExpression { std::optional dynamic_constant_index_; std::optional dynamic_constant_expr_; + // Symbolic expressions describing tensor contents when this expression is + // used as a shape-like value. + std::vector local_contents_; + // The resource, if kind_ == kResource. Not owned. 
XlaResource* resource_ = nullptr; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index e022b2fbec258b..a1dce933db5309 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -464,6 +464,12 @@ absl::Status XlaOpKernelContext::ConstantInputAsShape( ", result: ", num_elements); } *shape = TensorShape(dims); + const auto& contents = InputExpression(index).contents(); + for (int i = 0; i < shape->dims() && i < contents.size(); ++i) { + if (contents[i] && contents[i]->is_dynamic()) { + shape->set_expression(i, contents[i]); + } + } return absl::OkStatus(); } diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 151413d0f3bca1..e5ecae12669e65 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape_expr.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -41,6 +42,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/setround.h" #include "tensorflow/core/public/session_options.h" @@ -52,6 +54,31 @@ namespace { const char kScopedAllocatorAttrName[] = "_scoped_allocator"; const char kXlaShapeDerivedAttrName[] = "_xla_shape_derived"; +bool GetShapeFromDynamicAncestor(const Node* node, + TensorShapeProto* out_shape); + +bool IsDynamicExpressionProto(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kVariableId: + return true; + case ExpressionProto::kAddNode: + return IsDynamicExpressionProto(proto.add_node().lhs()) || + IsDynamicExpressionProto(proto.add_node().rhs()); + case ExpressionProto::kSubNode: + return IsDynamicExpressionProto(proto.sub_node().lhs()) || + IsDynamicExpressionProto(proto.sub_node().rhs()); + case ExpressionProto::kMulNode: + return IsDynamicExpressionProto(proto.mul_node().lhs()) || + IsDynamicExpressionProto(proto.mul_node().rhs()); + case ExpressionProto::kDivNode: + return IsDynamicExpressionProto(proto.div_node().lhs()) || + IsDynamicExpressionProto(proto.div_node().rhs()); + case ExpressionProto::kConstantValue: + case ExpressionProto::NODE_TYPE_NOT_SET: + return false; + } +} + // For stateless RNGs ops, they are pure but device-dependent. Those ops are not // constant-foldable. 
static absl::flat_hash_set* kBlockList = @@ -246,6 +273,13 @@ bool IsConstantFoldable( int64_t max_constant_size_in_bytes, std::unordered_map>* shape_replacement_map) { + TensorShapeProto dynamic_shape; + if (GetShapeFromDynamicAncestor(n, &dynamic_shape)) { + VLOG(1) << "Skipping constant folding for dynamic shape-derived node " + << n->name() << " op=" << n->type_string() + << " inferred_shape=" << dynamic_shape.DebugString(); + return false; + } if (n->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) { VLOG(1) << "Skipping constant folding for shape-derived node " << n->name() << " op=" << n->type_string(); @@ -468,6 +502,29 @@ bool GetShapeFromArgNode(const Node* node, TensorShapeProto* out_shape) { return false; } +bool GetShapeFromDynamicAncestor(const Node* node, TensorShapeProto* out_shape) { + std::vector stack = {node}; + absl::flat_hash_set visited; + while (!stack.empty()) { + const Node* current = stack.back(); + stack.pop_back(); + if (!visited.insert(current).second) { + continue; + } + if ((IsShapeOp(current) || + current->attrs().FindByString(kXlaShapeDerivedAttrName) != nullptr) && + GetShapeFromArgNode(current, out_shape)) { + return true; + } + for (const Edge* edge : current->in_edges()) { + if (!edge->IsControlEdge()) { + stack.push_back(edge->src()); + } + } + } + return false; +} + // Replaces constant-foldable shape node n by a vector of constants in // constant_graph, which is being built up for subsequent evaluation of constant // propagation. 
node_map is the mapping of nodes in the original graph to nodes @@ -480,10 +537,13 @@ void AddShapeNodeToConstantGraph( shape_replacement_map, std::unordered_map>* node_map, const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) { - TensorShapeProto user_inferred_shape; - bool has_dynamic = GetShapeFromArgNode(n, &user_inferred_shape); - + const bool has_dynamic = GetShapeFromDynamicAncestor(n, &user_inferred_shape); + if (has_dynamic) { + LOG(INFO) << "AddShapeNodeToConstantGraph preserving dynamic metadata for " + << n->name() << " op=" << n->type_string() + << " inferred_shape=" << user_inferred_shape.DebugString(); + } std::vector& added = (*node_map)[n]; const string& node_name = n->name(); for (const Tensor& t : shape_replacement_map.at(n)) { @@ -492,9 +552,11 @@ void AddShapeNodeToConstantGraph( auto builder = NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const") .Attr("dtype", t.dtype()) - .Attr("has_dynamic", has_dynamic) - .Attr("user_inferred_shape", user_inferred_shape) .Attr("value", t); + if (has_dynamic) { + builder.Attr("has_dynamic", has_dynamic) + .Attr("user_inferred_shape", user_inferred_shape); + } NodeDef def; CHECK(builder.Finalize(&def).ok()); Node* constant_node; @@ -613,15 +675,22 @@ bool ReplaceTensorWithConstant( } } const string& node_name = n->name(); - Node* constant_node; - TensorShapeProto user_inferred_shape; - bool has_dynamic = GetShapeFromArgNode(tensor.first, &user_inferred_shape); + const bool has_dynamic = + GetShapeFromDynamicAncestor(tensor.first, &user_inferred_shape); + if (has_dynamic) { + LOG(INFO) << "ReplaceTensorWithConstant preserving dynamic metadata for " + << tensor.first->name() << " output=" << tensor.second + << " inferred_shape=" << user_inferred_shape.DebugString(); + } + Node* constant_node; auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const") .Attr("dtype", constant.dtype()) - .Attr("has_dynamic", has_dynamic) - .Attr("user_inferred_shape", 
user_inferred_shape) .Attr("value", constant); + if (has_dynamic) { + builder.Attr("has_dynamic", has_dynamic) + .Attr("user_inferred_shape", user_inferred_shape); + } if (partition_device) { builder.Device(partition_device->name()); } diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 567a7c00171d79..57b0b0c39ecb40 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -742,6 +742,7 @@ cc_library( ":tensor_shape_proto_cc", "//tensorflow/core:lib", "@local_xla//xla:parse_flags_from_env", + "@local_xla//xla:shape_util", "@local_xla//xla/tsl/util:command_line_flags", ], alwayslink = 1, diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 163fd8f9b7b826..e805fed23db7ea 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -821,7 +821,6 @@ void TensorShapeBase::RemoveDimRange(int begin, int end) { new_exprs.resize(new_rank); } - ClearAllButDataType(); set_expressions(new_exprs); for (auto dval : vals) { @@ -864,7 +863,6 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, std::vector new_exprs(get_expressions().begin(), get_expressions().end()); - if (begin < static_cast(new_exprs.size())) { int64_t expr_end = end; if (expr_end > static_cast(new_exprs.size())) { @@ -876,14 +874,15 @@ absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, } vals.erase(vals.begin() + begin, vals.begin() + end); - ClearAllButDataType(); const int64_t new_rank = vals.size(); if (new_exprs.size() > static_cast(new_rank)) { new_exprs.resize(new_rank); } + ClearAllButDataType(); set_expressions(new_exprs); + absl::Status s = absl::OkStatus(); for (auto dval : vals) { s.Update(AddDimWithStatus(dval)); diff --git a/third_party/xla/xla/hlo/builder/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc index 3841206a9ebb8f..339b799f93d769 100644 --- 
a/third_party/xla/xla/hlo/builder/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/permutation_util.h" +#include "xla/printer.h" #include "xla/primitive_util.h" #include "xla/service/hlo.pb.h" #include "xla/service/shape_inference.h" @@ -518,6 +519,20 @@ static std::string ShapeToString(const ShapeProto& shape) { return absl::StrCat("[", absl::StrJoin(shape.dimensions(), ", "), "]"); } +static std::vector ContentsToProto( + absl::Span contents) { + std::vector protos; + protos.reserve(contents.size()); + for (const DExpr& expr : contents) { + ExpressionProto proto; + if (expr) { + expr.to_proto(&proto); + } + protos.push_back(std::move(proto)); + } + return protos; +} + void XlaBuilder::ToStringHelper(std::string* out, int ident, int64_t op_handle) const { const HloInstructionProto& instr = @@ -707,6 +722,37 @@ absl::Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op, return absl::OkStatus(); } +absl::Status XlaBuilder::SetInstructionContents(XlaOp op, + std::vector contents) { + auto it = handle_to_index_.find(op.handle()); + if (it == handle_to_index_.end()) { + return InvalidArgument("No XlaOp with handle %d", op.handle()); + } + const bool has_dynamic_content = + absl::c_any_of(contents, [](const DExpr& expr) { + return expr && expr->is_dynamic(); + }); + if (!has_dynamic_content) { + LOG(INFO) << "SetInstructionContents clearing non-dynamic contents for op " + << op.handle() << " size=" << contents.size(); + contents.clear(); + } else { + LOG(INFO) << "SetInstructionContents keeping dynamic contents for op " + << op.handle() << " size=" << contents.size(); + } + instruction_contents_.at(it->second) = std::move(contents); + return absl::OkStatus(); +} + +absl::StatusOr*> XlaBuilder::GetInstructionContents( + XlaOp op) const { + auto it = handle_to_index_.find(op.handle()); + if (it == handle_to_index_.end()) { + 
return InvalidArgument("No XlaOp with handle %d", op.handle()); + } + return &instruction_contents_.at(it->second); +} + absl::Status XlaBuilder::SetInstructionSharding( XlaOp op, const std::optional& sharding) { TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op)); @@ -784,7 +830,13 @@ absl::StatusOr XlaBuilder::Build( *entry.mutable_program_shape() = program_shape.ToProto(); entry.set_root_id(root_id); - for (auto& instruction : instructions_) { + for (size_t index = 0; index < instructions_.size(); ++index) { + auto& instruction = instructions_[index]; + if (!instruction_contents_[index].empty()) { + for (const auto& content : ContentsToProto(instruction_contents_[index])) { + *instruction.add_contents() = content; + } + } // Ensures that the instruction names are unique among the whole graph. instruction.set_name( GetFullName(instruction.name(), kNameSeparator, instruction.id())); @@ -810,6 +862,7 @@ absl::StatusOr XlaBuilder::Build( // Clear data held by this builder. this->instructions_.clear(); this->instruction_shapes_.clear(); + this->instruction_contents_.clear(); this->handle_to_index_.clear(); this->embedded_.clear(); this->parameter_numbers_.clear(); @@ -4945,6 +4998,7 @@ absl::StatusOr XlaBuilder::AddInstruction( TF_ASSIGN_OR_RETURN(Shape shape, Shape::FromProto(instructions_.back().shape())); instruction_shapes_.push_back(std::make_unique(std::move(shape))); + instruction_contents_.push_back({}); XlaOp op(handle, this); return op; diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h index 7abdf120e4c286..d7cb3ba4222897 100644 --- a/third_party/xla/xla/hlo/builder/xla_builder.h +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -484,6 +484,13 @@ class XlaBuilder { absl::Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute, std::string value); + // Associates symbolic contents metadata with a specific instruction. 
+ absl::Status SetInstructionContents(XlaOp op, std::vector contents); + + // Returns symbolic contents metadata attached to an instruction, if any. + absl::StatusOr*> GetInstructionContents( + XlaOp op) const; + // Looks up the HloInstruction and sets the sharding. If the sharding already // existed, then its value is updated. // @@ -1201,6 +1208,7 @@ class XlaBuilder { // A cache for the HloInstructionProto shapes, to avoid recreating Shape // objects from protos and to support the GetShapePtr() API. std::vector> instruction_shapes_; + std::vector> instruction_contents_; // Dynamic parameter configuration of this computation. DynamicParameterBinding dynamic_parameter_binding_; diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 0dc7d619afc370..a260980ee199cd 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -108,6 +108,51 @@ absl::Status EraseElementFromVector(PtrVec* container, T value) { container->erase(it); return absl::OkStatus(); } + +DynExpr* DynExprFromProtoForPrint(const ExpressionProto& proto) { + switch (proto.node_type_case()) { + case ExpressionProto::kConstantValue: + return DynExpr::_(proto.constant_value()); + case ExpressionProto::kVariableId: + return DynExpr::V(proto.variable_id()); + case ExpressionProto::kAddNode: { + const auto& add = proto.add_node(); + return new Add(DynExprFromProtoForPrint(add.lhs()), + DynExprFromProtoForPrint(add.rhs())); + } + case ExpressionProto::kSubNode: { + const auto& sub = proto.sub_node(); + return new Sub(DynExprFromProtoForPrint(sub.lhs()), + DynExprFromProtoForPrint(sub.rhs())); + } + case ExpressionProto::kMulNode: { + const auto& mul = proto.mul_node(); + return new Mul(DynExprFromProtoForPrint(mul.lhs()), + DynExprFromProtoForPrint(mul.rhs())); + } + case ExpressionProto::kDivNode: { + const auto& div = proto.div_node(); + return new Div(DynExprFromProtoForPrint(div.lhs()), + 
DynExprFromProtoForPrint(div.rhs())); + } + case ExpressionProto::NODE_TYPE_NOT_SET: + default: + return nullptr; + } +} + +std::string ContentsExprToString(const ExpressionProto& proto) { + std::unique_ptr expr(DynExprFromProtoForPrint(proto)); + if (expr == nullptr) { + return "_"; + } + if (!expr->is_dynamic() && expr->get_val() == kUnknownContentSentinel) { + return "_"; + } + StringPrinter printer; + expr->print(&printer); + return std::move(printer).ToString(); +} } // namespace HloInstruction::Users::~Users() = default; @@ -1373,6 +1418,14 @@ absl::StatusOr> HloInstruction::CreateFromProto( if (proto.has_frontend_attributes()) { instruction->set_frontend_attributes(proto.frontend_attributes()); } + if (proto.contents_size() > 0) { + std::vector contents; + contents.reserve(proto.contents_size()); + for (const auto& content : proto.contents()) { + contents.push_back(content); + } + instruction->set_contents(std::move(contents)); + } if (proto.has_statistics_viz()) { instruction->set_statistics_viz(proto.statistics_viz()); @@ -4235,6 +4288,18 @@ void HloInstruction::PrintExtraAttributes( FrontendAttributesToString(frontend_attributes())); }); } + if (has_contents()) { + printer.Next([this](Printer* printer) { + printer->Append("contents=["); + for (int64_t i = 0; i < contents().size(); ++i) { + if (i > 0) { + printer->Append(", "); + } + printer->Append(ContentsExprToString(contents()[i])); + } + printer->Append("]"); + }); + } if (opcode() != HloOpcode::kCall) { CHECK(!is_composite()) @@ -4356,6 +4421,9 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_frontend_attributes() = frontend_attributes(); + for (const auto& content : contents()) { + *proto.add_contents() = content; + } proto.set_is_composite(is_composite()); *proto.mutable_statistics_viz() = statistics_viz(); diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 59f542b6ec3b5b..800bc686d7dd4b 100644 --- 
a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -1883,6 +1883,19 @@ class HloInstruction { return rare()->frontend_attributes; } + void set_contents(std::vector contents) { + if (!has_rare() && contents.empty()) { + return; + } + mutable_rare()->contents = std::move(contents); + } + + const std::vector& contents() const { + return rare()->contents; + } + + bool has_contents() const { return has_rare() && !rare()->contents.empty(); } + std::optional get_frontend_attribute( absl::string_view key) const { auto it = rare()->frontend_attributes.map().find(key); @@ -2506,6 +2519,9 @@ class HloInstruction { // z' = const(20), frontend_attributes={?} FrontendAttributes frontend_attributes; + // Structured symbolic contents attached to this instruction. + std::vector contents; + // Used by kCall to determine if the Call instruction is a composite. bool is_composite; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc index 231b6ae2239d8f..0ecd3917cda69b 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc @@ -230,24 +230,30 @@ absl::StatusOr HloConstantFolding::Run( std::vector marked_constant_operands; for (const HloInstruction* operand : instruction->operands()) { if (operand->opcode() != HloOpcode::kConstant || - !operand->has_frontend_attributes()) { + !operand->has_contents()) { continue; } - const auto& attrs = operand->frontend_attributes().map(); - auto it = attrs.find("dynamic_constant_index"); - if (it == attrs.end()) { + bool has_dynamic_content = false; + for (const auto& content : operand->contents()) { + has_dynamic_content = + content.node_type_case() != ExpressionProto::kConstantValue && + content.node_type_case() != ExpressionProto::NODE_TYPE_NOT_SET; + if (has_dynamic_content) { + break; + } + } 
+ if (!has_dynamic_content) { continue; } source_has_dynamic_constant_marker = true; marked_constant_operands.push_back(absl::StrFormat( - "%s:index=%s literal=%s", operand->name(), it->second, - operand->literal().ToString())); + "%s:contents=%d literal=%s", operand->name(), + operand->contents().size(), operand->literal().ToString())); } if (source_has_dynamic_constant_marker) { VLOG(1) << "Skipping HloConstantFolding for " << instruction->name() << " (" << HloOpcodeString(instruction->opcode()) - << ") because source constant operands carry " - "dynamic_constant_index"; + << ") because source constant operands carry dynamic contents"; VLOG(1) << "Marked constant operands: " << absl::StrJoin(marked_constant_operands, ", "); continue; diff --git a/third_party/xla/xla/service/dynamic_constant_rewriter.cc b/third_party/xla/xla/service/dynamic_constant_rewriter.cc index 315c9d6615f3ba..56305dcb66a0f2 100644 --- a/third_party/xla/xla/service/dynamic_constant_rewriter.cc +++ b/third_party/xla/xla/service/dynamic_constant_rewriter.cc @@ -9,9 +9,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/statusor.h" -#include "absl/strings/numbers.h" #include "absl/strings/string_view.h" -#include "tsl/platform/protobuf.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -28,24 +26,7 @@ namespace { absl::StatusOr BuildDynamicConstantReplacement( HloInstruction* constant_instr) { TF_RET_CHECK(constant_instr->opcode() == HloOpcode::kConstant); - TF_RET_CHECK(constant_instr->has_frontend_attributes()); - - const auto& attrs = constant_instr->frontend_attributes().map(); - auto index_it = attrs.find("dynamic_constant_index"); - auto expr_it = attrs.find("dynamic_constant_expr"); - TF_RET_CHECK(index_it != attrs.end()); - TF_RET_CHECK(expr_it != attrs.end()); - - int64_t dynamic_index; - TF_RET_CHECK(absl::SimpleAtoi(index_it->second, &dynamic_index)) - << "Failed to parse 
dynamic_constant_index=" << index_it->second; - - ExpressionProto expr_proto; - TF_RET_CHECK(tsl::protobuf::TextFormat::ParseFromString(expr_it->second, - &expr_proto)) - << "Failed to parse dynamic_constant_expr=" << expr_it->second; - DExpr expr = DExprFromProto(expr_proto); - TF_RET_CHECK(expr); + TF_RET_CHECK(constant_instr->has_contents()); const Shape& shape = constant_instr->shape(); TF_RET_CHECK(shape.IsArray()); @@ -54,6 +35,35 @@ absl::StatusOr BuildDynamicConstantReplacement( TF_RET_CHECK(shape.dimensions_size() <= 1) << "Only scalar and rank-1 marked constants are supported"; + int64_t dynamic_index = 0; + DExpr expr; + for (int64_t i = 0; i < constant_instr->contents().size(); ++i) { + DExpr candidate = DExprFromProto(constant_instr->contents()[i]); + if (candidate && candidate->is_dynamic()) { + dynamic_index = i; + expr = std::move(candidate); + break; + } + } + + if (!expr) { + const auto& attrs = constant_instr->frontend_attributes().map(); + auto index_it = attrs.find("dynamic_constant_index"); + auto expr_it = attrs.find("dynamic_constant_expr"); + TF_RET_CHECK(index_it != attrs.end()); + TF_RET_CHECK(expr_it != attrs.end()); + + TF_RET_CHECK(absl::SimpleAtoi(index_it->second, &dynamic_index)) + << "Failed to parse dynamic_constant_index=" << index_it->second; + + ExpressionProto expr_proto; + TF_RET_CHECK(tsl::protobuf::TextFormat::ParseFromString(expr_it->second, + &expr_proto)) + << "Failed to parse dynamic_constant_expr=" << expr_it->second; + expr = DExprFromProto(expr_proto); + TF_RET_CHECK(expr); + } + int64_t carrier_bound; if (shape.dimensions_size() == 0) { dynamic_index = 0; @@ -96,8 +106,7 @@ absl::StatusOr BuildDynamicConstantReplacement( HloInstruction* base_constant = computation->AddInstruction(constant_instr->Clone()); - base_constant->erase_frontend_attribute("dynamic_constant_index"); - base_constant->erase_frontend_attribute("dynamic_constant_expr"); + base_constant->set_contents({}); Shape update_shape = 
ShapeUtil::MakeShape(shape.element_type(), {1}); HloInstruction* update = computation->AddInstruction( HloInstruction::CreateReshape(update_shape, runtime_value)); @@ -109,10 +118,7 @@ absl::StatusOr BuildDynamicConstantReplacement( } bool IsMarkedDynamicConstant(const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kConstant && - instr->has_frontend_attributes() && - instr->get_frontend_attribute("dynamic_constant_index").has_value() && - instr->get_frontend_attribute("dynamic_constant_expr").has_value(); + return instr->opcode() == HloOpcode::kConstant && instr->has_contents(); } } // namespace @@ -133,12 +139,7 @@ absl::StatusOr DynamicConstantRewriter::Run( << " literal=" << instruction->literal().ToString() << " marked=" << is_marked; if (is_marked) { - VLOG(1) << " dynamic_constant_index=" - << *instruction->get_frontend_attribute( - "dynamic_constant_index") - << " dynamic_constant_expr=" - << *instruction->get_frontend_attribute( - "dynamic_constant_expr"); + VLOG(1) << " contents_size=" << instruction->contents().size(); marked_constants.push_back(instruction); } } diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index 713db5886c6f4e..21479c3ae6c7be 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -113,7 +113,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 92 +// Next ID: 93 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -328,6 +328,9 @@ message HloInstructionProto { // Frontend attributes to pass to the XLA backend. xla.FrontendAttributes frontend_attributes = 68; + // Structured symbolic contents attached to this instruction. + repeated xla.ExpressionProto contents = 92; + // Specifies if all elements updated are guaranteed to be unique by // the caller. 
bool unique_indices = 69; diff --git a/third_party/xla/xla/shape.cc b/third_party/xla/xla/shape.cc index 5e256a6303d6b1..09cef2b7d6807f 100644 --- a/third_party/xla/xla/shape.cc +++ b/third_party/xla/xla/shape.cc @@ -391,6 +391,10 @@ DynExpr* Div::s() { } std::ostream& operator<<(std::ostream& os, DynExpr* expr) { + if (expr == nullptr) { + os << "_"; + return os; + } StringPrinter printer; expr->print(&printer); os << std::move(printer).ToString(); diff --git a/third_party/xla/xla/shape_dynexpr.h b/third_party/xla/xla/shape_dynexpr.h index a5713d92fbfc2b..16ceb78325ca26 100644 --- a/third_party/xla/xla/shape_dynexpr.h +++ b/third_party/xla/xla/shape_dynexpr.h @@ -40,6 +40,8 @@ enum class DExprKind { kMul, kDiv, }; +inline constexpr int64_t kMissingExpressionSentinel = -999; +inline constexpr int64_t kUnknownContentSentinel = -444; class DynExpr { public: @@ -198,6 +200,10 @@ class Constant : public DynExpr { } DExprKind kind() const override { return DExprKind::kConstant; } void print(xla::Printer* printer) const override { + if (value == kMissingExpressionSentinel) { + printer->Append("_"); + return; + } if (value < 0) { printer->Append("("); }