Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions tensorflow/compiler/jit/device_compilation_cluster_signature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,31 @@ limitations under the License.

#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h"

#include "absl/strings/str_cat.h"
#include <string>
#include <utility>
#include <variant>

namespace tensorflow {
namespace {
using Signature = DeviceCompilationClusterSignature;
using ConstantTensor = Signature::ConstantTensor;
using TensorTypeAndShape = Signature::TensorTypeAndShape;

// Functor that converts a Signature's arg to a human readable string.
struct SignatureHumanStringAppender {
explicit SignatureHumanStringAppender(std::string* dest) : dest(dest) {}
std::string* dest;
void operator()(const Tensor& arg) {
absl::StrAppend(dest, "; ", arg.DebugString());
void operator()(const ConstantTensor& arg) {
absl::StrAppend(dest, "; ", arg.value.DebugString());
if (!arg.contents.empty()) {
absl::StrAppend(dest, " contents=[");
for (int i = 0; i < arg.contents.size(); ++i) {
if (i > 0) absl::StrAppend(dest, ",");
absl::StrAppend(dest, arg.contents[i].DebugString());
}
absl::StrAppend(dest, "]");
}
}
void operator()(const TensorTypeAndShape& arg) {
absl::StrAppend(dest, ",", DataTypeString(arg.first));
Expand All @@ -40,18 +50,29 @@ struct SignatureHumanStringAppender {
// Functor that compares the arg values of two different signatures. Returns
// true when the args are not equal.
struct SignatureNotEqual {
bool operator()(const Tensor& arg, const Tensor& other) {
return arg.dtype() != other.dtype() || arg.shape() != other.shape() ||
arg.tensor_data() != other.tensor_data();
bool operator()(const ConstantTensor& arg, const ConstantTensor& other) {
if (arg.value.dtype() != other.value.dtype() ||
arg.value.shape() != other.value.shape() ||
arg.value.tensor_data() != other.value.tensor_data() ||
arg.contents.size() != other.contents.size()) {
return true;
}
for (int i = 0; i < arg.contents.size(); ++i) {
if (arg.contents[i].SerializeAsString() !=
other.contents[i].SerializeAsString()) {
return true;
}
}
return false;
}
bool operator()(const TensorTypeAndShape& arg,
const TensorTypeAndShape& other) {
return arg.first != other.first || arg.second != other.second;
}
bool operator()(const Tensor& arg, const TensorTypeAndShape& other) {
bool operator()(const ConstantTensor& arg, const TensorTypeAndShape& other) {
return true;
}
bool operator()(const TensorTypeAndShape& arg, const Tensor& other) {
bool operator()(const TensorTypeAndShape& arg, const ConstantTensor& other) {
return true;
}
};
Expand All @@ -61,12 +82,16 @@ struct SignatureNotEqual {
struct SignatureHashCombiner {
explicit SignatureHashCombiner(const uint64 h) : h(h) {}
uint64 h;
uint64 operator()(const Tensor& arg) {
h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.dtype())));
uint64 operator()(const ConstantTensor& arg) {
h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.value.dtype())));
h = Hash64Combine(
h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
for (int dim = 0; dim < arg.dims(); ++dim) {
h = Hash64Combine(h, std::hash<int>()(arg.dim_size(dim)));
h, Hash64(arg.value.tensor_data().data(), arg.value.tensor_data().size()));
for (int dim = 0; dim < arg.value.dims(); ++dim) {
h = Hash64Combine(h, std::hash<int>()(arg.value.dim_size(dim)));
}
for (const xla::ExpressionProto& expr : arg.contents) {
std::string serialized = expr.SerializeAsString();
h = Hash64Combine(h, Hash64(serialized.data(), serialized.size()));
}
return h;
}
Expand Down Expand Up @@ -120,7 +145,8 @@ absl::StatusOr<Signature> Signature::Build(
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kConstantResource:
signature.args.push_back(arg.constant_value);
signature.args.push_back(
ConstantTensor{arg.constant_value, arg.constant_value_expressions});
break;
case XlaCompiler::Argument::kParameter:
case XlaCompiler::Argument::kResource:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ limitations under the License.

#include <utility>
#include <variant>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {

Expand All @@ -34,7 +36,11 @@ struct DeviceCompilationClusterSignature {
// argument number. Tensors must be in host memory.
using TensorTypeAndShape =
std::pair<DataType, absl::InlinedVector<int64_t, 4>>;
absl::InlinedVector<std::variant<Tensor, TensorTypeAndShape>, 8> args;
struct ConstantTensor {
Tensor value;
std::vector<xla::ExpressionProto> contents;
};
absl::InlinedVector<std::variant<ConstantTensor, TensorTypeAndShape>, 8> args;

bool operator==(const DeviceCompilationClusterSignature& other) const;

Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/jit/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
Flag("tf_xla_enable_dynamic_sizes",
&mark_for_compilation_flags->tf_xla_enable_dynamic_sizes,
"Enable dynamic sizes support."),
Flag("tf_xla_enable_symbolic_content",
&mark_for_compilation_flags->tf_xla_enable_symbolic_content,
"Enable symbolic content propagation."),
Flag("tf_xla_persistent_cache_directory",
&mark_for_compilation_flags->tf_xla_persistent_cache_directory,
"If non-empty, JIT-compiled executables are saved to and loaded "
Expand Down Expand Up @@ -262,6 +265,7 @@ void AllocateAndParseFlags() {
->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false;
mark_for_compilation_flags->tf_xla_enable_dynamic_sizes = false;
mark_for_compilation_flags->tf_xla_enable_symbolic_content = false;
mark_for_compilation_flags->tf_xla_persistent_cache_directory = "";
mark_for_compilation_flags->tf_xla_persistent_cache_device_types = "";
mark_for_compilation_flags->tf_xla_persistent_cache_read_only = false;
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ struct MarkForCompilationPassFlags {
// If true enables support of dynamic sizes.
bool tf_xla_enable_dynamic_sizes;

// If true enables symbolic content propagation.
bool tf_xla_enable_symbolic_content;

// If non-empty, JIT-compiled executables are saved to and loaded from the
// specified file system directory path.
std::string tf_xla_persistent_cache_directory;
Expand Down
92 changes: 85 additions & 7 deletions tensorflow/compiler/jit/kernels/xla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_client.h"
#include "xla/printer.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/shape_dynexpr.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/protobuf/error_codes.pb.h"
#include "tensorflow/core/framework/allocator.h"
Expand Down Expand Up @@ -384,7 +385,6 @@ GetXlaCompilerArgsAndSnapshotVariables(
return result;
}


std::unique_ptr<DimExpr> ExprFromProto(const ExpressionProto& proto) {
switch (proto.node_type_case()) {
case ExpressionProto::kConstantValue:
Expand Down Expand Up @@ -448,7 +448,6 @@ static xla::DExpr DimExprToDExpr(const DimExpr* e) {
return xla::DExpr::Unknown();
}


absl::Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
Expand Down Expand Up @@ -511,6 +510,71 @@ absl::Status CompileToLocalExecutable(
XlaBatchMatcher* xla_batch_matcher =
xla_device_compiler->xla_batch_matcher();
std::optional<xla::DExpr> dynamic_dim_expr;
auto maybe_attach_shape_contents_from_attrs =
[&](int arg_index, const auto& attr_map,
const std::string& node_name) {
auto& arg = norm_args[arg_index];
if (arg.kind != XlaCompiler::Argument::kConstant) {
return;
}

bool has_dynamic = false;
auto has_dynamic_it = attr_map.find("has_dynamic");
if (has_dynamic_it == attr_map.end()) {
return;
}
has_dynamic = has_dynamic_it->second.b();
if (!has_dynamic) {
return;
}

auto inferred_shape_it = attr_map.find("user_inferred_shape");
if (inferred_shape_it == attr_map.end()) {
LOG(INFO) << "XlaCompileOp saw has_dynamic for const arg "
<< arg_index << " node=" << node_name
<< " but no user_inferred_shape attr";
return;
}

TensorShapeProto inferred_shape_proto;
inferred_shape_proto = inferred_shape_it->second.shape();

TensorShape inferred_shape(inferred_shape_proto);
if (!TensorShapeUtils::IsVector(arg.constant_value.shape()) ||
arg.constant_value.NumElements() != inferred_shape.dims()) {
LOG(INFO) << "XlaCompileOp const arg " << arg_index
<< " node=" << node_name
<< " has dynamic shape metadata but tensor shape "
<< arg.constant_value.shape().DebugString()
<< " does not match inferred rank " << inferred_shape.dims();
return;
}

arg.constant_value_expressions.clear();
arg.constant_value_expressions.reserve(inferred_shape.dims());
for (int64_t i = 0; i < inferred_shape.dims(); ++i) {
xla::ExpressionProto expr;
const xla::DExpr& dim_expr = inferred_shape.get_expression(i);
if (dim_expr && dim_expr->is_dynamic()) {
dim_expr->to_proto(&expr);
} else if (arg.constant_value.dtype() == DT_INT32) {
expr.set_constant_value(arg.constant_value.flat<int32>()(i));
} else if (arg.constant_value.dtype() == DT_INT64) {
expr.set_constant_value(arg.constant_value.flat<int64_t>()(i));
} else {
LOG(INFO) << "XlaCompileOp const arg " << arg_index
<< " node=" << node_name
<< " has unsupported dtype for inferred shape contents: "
<< DataTypeString(arg.constant_value.dtype());
arg.constant_value_expressions.clear();
return;
}
arg.constant_value_expressions.push_back(std::move(expr));
}
LOG(INFO) << "XlaCompileOp recovered " << arg.constant_value_expressions.size()
<< " constant_value_expressions for const arg " << arg_index
<< " node=" << node_name << " from user_inferred_shape";
};
auto record_dynamic_dim_value = [&](int64_t dim_size, xla::DExpr expr) {
if (!saw_dynamic_dim_value) {
saw_dynamic_dim_value = true;
Expand All @@ -536,6 +600,7 @@ absl::Status CompileToLocalExecutable(
VLOG(1) << "XlaCompileOp retrieved shape-derived marker for arg "
<< arg_index << " node=" << node_name;
}
maybe_attach_shape_contents_from_attrs(arg_index, attr_map, node_name);

// Special case for _dynamic_dim...
auto dyn_dim_attr = attr_map.find("_dynamic_dim");
Expand Down Expand Up @@ -632,6 +697,21 @@ absl::Status CompileToLocalExecutable(
return;
}

auto set_constant_contents = [&]<typename T>(int rewrite_index) {
arg.constant_value_expressions.clear();
const int64_t num_elements = arg.constant_value.NumElements();
arg.constant_value_expressions.reserve(num_elements);
for (int64_t i = 0; i < num_elements; ++i) {
xla::ExpressionProto expr;
if (i == rewrite_index) {
dynamic_dim_expr->to_proto(&expr);
} else {
expr.set_constant_value(arg.constant_value.flat<T>()(i));
}
arg.constant_value_expressions.push_back(std::move(expr));
}
};

if (arg.constant_value.dtype() == DT_INT32) {
auto flat = arg.constant_value.flat<int32>();
int rewrite_index = -1;
Expand All @@ -650,9 +730,8 @@ absl::Status CompileToLocalExecutable(
VLOG(1) << "XlaCompileOp int32 constant arg " << arg_index
<< " index " << rewrite_index
<< " matches dynamic_dim_value=" << dynamic_dim_value;
arg.dynamic_constant_index = rewrite_index;
arg.dynamic_constant_expr = dynamic_dim_expr;
mutable_flat(rewrite_index) = filled_batch;
set_constant_contents.template operator()<int32>(rewrite_index);
}
} else if (arg.constant_value.dtype() == DT_INT64) {
auto flat = arg.constant_value.flat<int64_t>();
Expand All @@ -670,9 +749,8 @@ absl::Status CompileToLocalExecutable(
VLOG(1) << "XlaCompileOp int64 constant arg " << arg_index
<< " index " << rewrite_index
<< " matches dynamic_dim_value=" << dynamic_dim_value;
arg.dynamic_constant_index = rewrite_index;
arg.dynamic_constant_expr = dynamic_dim_expr;
mutable_flat(rewrite_index) = filled_batch;
set_constant_contents.template operator()<int64_t>(rewrite_index);
}
}
};
Expand All @@ -687,7 +765,7 @@ absl::Status CompileToLocalExecutable(
TensorShape& shp = std::get<TensorShape>(norm_args[i].shape);
for (int j = 0; j < shp.get_expressions().size(); ++j) {
auto e = shp.get_expression(j);
if (e->is_dynamic()) {
if (e && e->is_dynamic()) {
int64_t old = shp.dim_size(j);
old_vars.push_back({i, j, old});
xla::DExpr padded_expr = xla::DExpr::Const(filled_batch);
Expand Down
36 changes: 33 additions & 3 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,45 @@ absl::StatusOr<bool> MarkForCompilationPassImpl::Initialize() {
if (debug_options_.enable_dynamic_sizes) {
LogExpressionsViaGraphProperties(*graph_);
TF_RETURN_IF_ERROR(AssignDimVars());
auto has_dynamic_input_expression = [&](const Node* n) {
for (const Edge* edge : n->in_edges()) {
if (edge->IsControlEdge()) {
continue;
}
const Node* src = edge->src();
auto it = expr_map.find(src->name());
if (it == expr_map.end()) {
continue;
}
const int output_index = edge->src_output();
if (output_index < 0 || output_index >= it->second.size()) {
continue;
}
for (const auto& expr_ptr : it->second[output_index]) {
if (expr_ptr == nullptr) {
continue;
}
xla::DExpr dyn = DimExprToDExpr(expr_ptr.get());
if (dyn && dyn->is_dynamic()) {
return true;
}
}
}
return false;
};
for (Node* n : graph_->op_nodes()) {
bool mark_shape_derived = false;
if (n->type_string() == "Shape" || n->type_string() == "ShapeN") {
mark_shape_derived = true;
mark_shape_derived = has_dynamic_input_expression(n);
} else if (n->type_string() == "Cast") {
for (const Edge* edge : n->in_edges()) {
if (edge->IsControlEdge()) continue;
if (edge->IsControlEdge()) {
continue;
}
const Node* src = edge->src();
if (src->type_string() == "Shape" || src->type_string() == "ShapeN") {
if ((src->type_string() == "Shape" ||
src->type_string() == "ShapeN") &&
has_dynamic_input_expression(src)) {
mark_shape_derived = true;
break;
}
Expand Down
Loading