Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensorflow/compiler/jit/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ cc_library(
":flags_headers",
":tf_graph_to_hlo_compiler",
":xla_compile_util",
":xla_batch_matcher",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
Expand Down Expand Up @@ -2005,3 +2007,13 @@ tf_cuda_cc_test(
"@local_xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
],
)

# Support library for matching batch sizes against configured dynamic-size
# padding values (see xla_batch_matcher.h). Instantiated by DeviceCompiler
# only when the tf_xla_enable_dynamic_sizes flag is set.
cc_library(
    name = "xla_batch_matcher",
    srcs = ["xla_batch_matcher.cc"],
    hdrs = ["xla_batch_matcher.h"],
    deps = [
        "//tensorflow/core/platform:logging",
        "@local_xla//xla:debug_options_flags",
    ],
)
24 changes: 23 additions & 1 deletion tensorflow/compiler/jit/device_compilation_profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/core/framework/attr_value.pb.h"
Expand All @@ -32,9 +33,30 @@ limitations under the License.
namespace tensorflow {
namespace {
bool ShouldBeMegamorphic(int64_t compile_count, int64_t execution_count) {
const int64_t kCompileThreshold = 10;
int64_t kCompileThreshold = 10;
const int64_t kMinExecutionsPerCompile = 50;

int64_t tf_xla_threshold_for_megamorphic =
GetMarkForCompilationPassFlags()->tf_xla_threshold_for_megamorphic;

// Negative values other than -1 cannot be used.
if (tf_xla_threshold_for_megamorphic < -1) {
LOG(FATAL) << "The value for the tf_xla_threshold_for_megamorphic flag "
<< "is out of range.\n"
<< "Allowed ranges are (-1) to "
<< std::numeric_limits<int64_t>::max()
<< " got " << tf_xla_threshold_for_megamorphic << ".";
}

// -1: marking clusters as megamorphic is disabled.
// 0: default TensorFlow behaviour (kCompileThreshold stays at 10).
// Any other positive value overrides the compilation threshold.
if (tf_xla_threshold_for_megamorphic == -1) {
return false;
} else if (tf_xla_threshold_for_megamorphic > 0) {
kCompileThreshold = tf_xla_threshold_for_megamorphic;
}

// This heuristic is trying to capture the following property: have we sunk a
// certain minimum amount of compile time into the cluster that didn't quite
// "pay off"?
Expand Down
12 changes: 11 additions & 1 deletion tensorflow/compiler/jit/device_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h"
#include "tensorflow/compiler/jit/xla_compile_util.h"
#include "tensorflow/compiler/jit/xla_batch_matcher.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
Expand Down Expand Up @@ -125,7 +126,8 @@ class DeviceCompiler : public ResourceBase {
DeviceCompilerClient<ExecutableType, ClientType>* compiler_client() {
return compiler_client_.get();
}

XlaBatchMatcher* xla_batch_matcher() { return xla_batch_matcher_.get(); }

string DebugString() const override;

private:
Expand Down Expand Up @@ -177,6 +179,9 @@ class DeviceCompiler : public ResourceBase {
// Pool of threads for asynchronous compilations.
std::unique_ptr<thread::ThreadPool> async_compiler_threads_;

// Specified dynamic batch padding values.
std::unique_ptr<XlaBatchMatcher> xla_batch_matcher_;

mutex cluster_mutexes_mu_;
absl::flat_hash_map<DeviceCompilationClusterSignature, std::unique_ptr<mutex>,
DeviceCompilationClusterSignature::Hash>
Expand Down Expand Up @@ -225,6 +230,11 @@ DeviceCompiler<ExecutableType, ClientType>::DeviceCompiler(
async_compiler_threads_ = std::make_unique<tensorflow::thread::ThreadPool>(
tensorflow::Env::Default(), "async_compiler_threads",
kNumAsyncDeviceCompilerThreads);

MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
if (flags->tf_xla_enable_dynamic_sizes) {
xla_batch_matcher_ = std::make_unique<XlaBatchMatcher>();
}
}

template <typename ExecutableType, typename ClientType>
Expand Down
98 changes: 95 additions & 3 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
Expand Down Expand Up @@ -54,9 +55,14 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

#include "tensorflow/core/framework/tensor_shape.pb.h"
namespace tensorflow {

// Op types that are known to fail under XLA compilation when dynamic sizes
// are enabled. Subgraphs containing any of these ops have their
// kXlaCompiledKernelAttr set to false (see SubgraphHasFailingOps).
// NOTE(review): this is a file-scope global with a non-trivial constructor;
// consider a function-local static to sidestep static-init-order issues —
// confirm against project style.
static const absl::flat_hash_set<absl::string_view> kFailingOps = {
    "Where",
    // add more here
};

const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
Expand Down Expand Up @@ -114,6 +120,30 @@ void MarkGuaranteedConstants(
}
}

// Renders an ExpressionProto as a human-readable infix string, e.g.
// "(Var(1) * 4)". An unset or unrecognized node type renders as "<none>".
std::string ExprProtoToString(const ExpressionProto& e) {
  // Shared formatter for the four binary arithmetic node kinds; each carries
  // lhs()/rhs() sub-expressions that are rendered recursively.
  auto infix = [](const char* op, const auto& node) {
    return absl::StrCat("(", ExprProtoToString(node.lhs()), " ", op, " ",
                        ExprProtoToString(node.rhs()), ")");
  };
  switch (e.node_type_case()) {
    case ExpressionProto::kConstantValue:
      return std::to_string(e.constant_value());
    case ExpressionProto::kVariableId:
      return absl::StrCat("Var(", e.variable_id(), ")");
    case ExpressionProto::kAddNode:
      return infix("+", e.add_node());
    case ExpressionProto::kSubNode:
      return infix("-", e.sub_node());
    case ExpressionProto::kMulNode:
      return infix("*", e.mul_node());
    case ExpressionProto::kDivNode:
      return infix("/", e.div_node());
    default:
      return "<none>";
  }
}

struct OutputInputTensorPairHasher {
uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
return Hash64Combine(OutputTensor::Hash()(s.first),
Expand Down Expand Up @@ -369,6 +399,19 @@ class Encapsulator {

namespace {

bool BuildOutputShapeProto(const Node& node, int output_slot,
TensorShapeProto* proto) {
AttrSlice attrs = node.attrs();
auto shape_attr =
attrs.FindByString(kXlaInferredOutputTensorShapesAttrName);
if (shape_attr == nullptr || !shape_attr->has_list() ||
shape_attr->list().shape_size() <= output_slot) {
return false;
}
*proto = shape_attr->list().shape(output_slot);
return true;
}

// Return in 'sorted' a topological sort of clusters according to the
// dependencies encoded in ancestors. clusters is the list of all clusters
// including clusters that are not present in the ancestors map. has_successors
Expand Down Expand Up @@ -451,6 +494,31 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {

Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }

// Serializes the expression tree rooted at `expr` into an ExpressionProto.
// Constants and variables become leaf fields; the four arithmetic node
// types recurse into their lhs/rhs children.
//
// NOTE(review): an expression node outside {Constant, Variable, Add, Mul,
// Sub, Div} leaves `proto` unset (node_type_case() == NODE_TYPE_NOT_SET)
// with no diagnostic — confirm this silent fallthrough is intended.
void ExprToProto(xla::DynExpr* expr, ExpressionProto* proto) {
  // Underlying expression node; its concrete type is probed below via
  // dynamic_cast, so the order of branches only matters for readability.
  auto e = expr->s();
  if (xla::Constant* c = dynamic_cast<xla::Constant*>(e)) {
    // Leaf: literal value.
    proto->set_constant_value(c->get_val());
  } else if (xla::Variable* v = dynamic_cast<xla::Variable*>(e)) {
    // Leaf: reference to a dynamic-dimension variable by id.
    proto->set_variable_id(v->get_id());
  } else if (xla::Add* a = dynamic_cast<xla::Add*>(e)) {
    auto* add_msg = proto->mutable_add_node();
    ExprToProto(a->get_lhs(), add_msg->mutable_lhs());
    ExprToProto(a->get_rhs(), add_msg->mutable_rhs());
  } else if (xla::Mul* m = dynamic_cast<xla::Mul*>(e)) {
    auto* mul_msg = proto->mutable_mul_node();
    ExprToProto(m->get_lhs(), mul_msg->mutable_lhs());
    ExprToProto(m->get_rhs(), mul_msg->mutable_rhs());
  } else if (xla::Sub* s = dynamic_cast<xla::Sub*>(e)) {
    auto* sub_msg = proto->mutable_sub_node();
    ExprToProto(s->get_lhs(), sub_msg->mutable_lhs());
    ExprToProto(s->get_rhs(), sub_msg->mutable_rhs());
  } else if (xla::Div* d = dynamic_cast<xla::Div*>(e)) {
    auto* div_msg = proto->mutable_div_node();
    ExprToProto(d->get_lhs(), div_msg->mutable_lhs());
    ExprToProto(d->get_rhs(), div_msg->mutable_rhs());
  }
}

absl::Status Encapsulator::Subgraph::RecordArg(
const Edge* edge,
const absl::flat_hash_map<const Node*, Node*>& node_images,
Expand All @@ -470,6 +538,22 @@ absl::Status Encapsulator::Subgraph::RecordArg(
DataType dtype = edge->dst()->input_type(edge->dst_input());
builder.Attr("T", dtype);
builder.Attr("index", arg_index);
AttrSlice attrs = src_node->attrs();
TensorShapeProto output_shape_proto;
if (BuildOutputShapeProto(*src_node, src_slot, &output_shape_proto)) {
VLOG(1) << "Adding following output shapes for node " << src_node->name()
<< " : " << output_shape_proto.DebugString();
builder.Attr("_output_shapes", {output_shape_proto});
builder.Attr(kXlaInferredOutputShapesAttrName, {output_shape_proto});
} else {
// Otherwise, if the cluster argument is a real argument, forward its
// "_dynamic_dim" annotation when present.
auto build_attr = attrs.FindByString("_dynamic_dim");
if (build_attr) {
VLOG(1) << "Found Dynamic dimension in " << src_node->name() << ":"
<< src_slot;
builder.Attr("_dynamic_dim", *build_attr);
}
}
absl::Status s = builder.Finalize(&arg_def);
if (!s.ok()) return s;

Expand Down Expand Up @@ -1143,6 +1227,14 @@ static absl::Status RenumberArguments(Graph* graph,
return absl::OkStatus();
}

static bool SubgraphHasFailingOps(const Graph& g) {
for (Node* n : g.op_nodes()) {
if (n->IsRetval()) continue;
if (kFailingOps.contains(n->def().op())) return true;
}
return false;
}

absl::Status EncapsulateSubgraphsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateSubgraphsPass::Run";
Expand Down Expand Up @@ -1289,8 +1381,8 @@ absl::Status EncapsulateSubgraphsPass::Run(

// TODO(phawkins): add a forward is-constant analysis, similarly split
// outputs into host-memory constants and device-memory non-constants.

AddNodeAttr(kXlaCompiledKernelAttr, true, node);
bool compile_enabled = !SubgraphHasFailingOps(**subgraph);
AddNodeAttr(kXlaCompiledKernelAttr, compile_enabled, node);
AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
return absl::OkStatus();
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/jit/encapsulate_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ absl::Status PostprocessControlEdgesBetweenOutsideCompilations(
} // namespace

const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
const char kXlaInferredOutputTensorShapesAttrName[] =
"_xla_inferred_output_tensor_shapes";
const char kXlaInferredOutputShapesAttrName[] =
"_xla_inferred_output_shapes";

const char kXlaConnectedToXlaComputationAttrName[] =
"_xla_connected_to_xla_computation";
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/compiler/jit/encapsulate_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ namespace tensorflow {
// a list of PartialTensorShape objects.
extern const char kXlaInferredShapesAttrName[];

// Attribute marking Grappler-inferred output TensorShapeProtos. Attribute
// value is a list of TensorShapeProto objects and may include ExpressionProto
// annotations when available.
extern const char kXlaInferredOutputTensorShapesAttrName[];

// Attribute carrying inferred output TensorShapeProtos on encapsulated _Arg
// nodes for XlaCompileOp argument reconstruction.
extern const char kXlaInferredOutputShapesAttrName[];

// Infers output shapes for all nodes in graph `g`. The output shapes will be
// stored in node attribute `kXlaInferredShapesAttrName`.
//
Expand Down
21 changes: 21 additions & 0 deletions tensorflow/compiler/jit/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
Flag("tf_xla_max_cluster_size",
&mark_for_compilation_flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_annotate_cluster_id",
&mark_for_compilation_flags->tf_xla_annotate_cluster_id,
"Allow operator names to influence clustering scheme."
"Operators whose name starting with .cluster.{id} will likely"
"to be clustered together if the ids are the same number. "
".cluster.none will not be clustered with those having numbered id"),
Flag("tf_xla_cluster_parallel",
&mark_for_compilation_flags->tf_xla_cluster_parallel,
"Split parallel compute subgraph info different clusters"),
Flag(
"tf_xla_ops_to_cluster",
&mark_for_compilation_flags->tf_xla_ops_to_cluster,
Expand Down Expand Up @@ -155,6 +164,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
&mark_for_compilation_flags->tf_xla_deterministic_cluster_names,
"Causes the function names assigned by auto clustering to be "
"deterministic from run to run."),
Flag("tf_xla_enable_dynamic_sizes",
&mark_for_compilation_flags->tf_xla_enable_dynamic_sizes,
"Enable dynamic sizes support."),
Flag("tf_xla_persistent_cache_directory",
&mark_for_compilation_flags->tf_xla_persistent_cache_directory,
"If non-empty, JIT-compiled executables are saved to and loaded "
Expand All @@ -175,6 +187,11 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
&mark_for_compilation_flags->tf_xla_persistent_cache_prefix,
"Specifies the persistance cache prefix. Default is "
"\"xla_compile_cache\""),
Flag("tf_xla_threshold_for_megamorphic",
&mark_for_compilation_flags->tf_xla_threshold_for_megamorphic,
"Sets the threshold for marking a cluster megamorphic. "
"Setting it to -1 disables marking clusters megamorphic."
"Setting it to 0 uses the default behaviour of TensorFlow."),
Flag("tf_xla_sparse_core_disable_table_stacking",
&sparse_core_flags->tf_xla_sparse_core_disable_table_stacking,
"Disable table stacking for all the tables passed to the SparseCore"
Expand Down Expand Up @@ -232,15 +249,19 @@ void AllocateAndParseFlags() {
mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
mark_for_compilation_flags->tf_xla_max_cluster_size =
std::numeric_limits<int32>::max();
mark_for_compilation_flags->tf_xla_annotate_cluster_id = false;
mark_for_compilation_flags->tf_xla_cluster_parallel = false;
mark_for_compilation_flags->tf_xla_clustering_debug = false;
mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
mark_for_compilation_flags->tf_xla_clustering_fuel =
std::numeric_limits<int64_t>::max();
mark_for_compilation_flags->tf_xla_threshold_for_megamorphic = 0;
mark_for_compilation_flags
->tf_xla_disable_deadness_safety_checks_for_debugging = false;
mark_for_compilation_flags
->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false;
mark_for_compilation_flags->tf_xla_enable_dynamic_sizes = false;
mark_for_compilation_flags->tf_xla_persistent_cache_directory = "";
mark_for_compilation_flags->tf_xla_persistent_cache_device_types = "";
mark_for_compilation_flags->tf_xla_persistent_cache_read_only = false;
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/compiler/jit/flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ struct MarkForCompilationPassFlags {
// Maximum number of operators in an XLA compilation.
int32 tf_xla_max_cluster_size;

// Enable operator name to influence clustering decision
bool tf_xla_annotate_cluster_id;

// Split a parallel compute subgraph into different clusters.
bool tf_xla_cluster_parallel;

// If non-empty, limit XLA clustering to the following TF operations.
string tf_xla_ops_to_cluster;

Expand Down Expand Up @@ -93,6 +99,9 @@ struct MarkForCompilationPassFlags {
// so that they remain stable from run to run of auto clusteing.
bool tf_xla_deterministic_cluster_names;

// If true enables support of dynamic sizes.
bool tf_xla_enable_dynamic_sizes;

// If non-empty, JIT-compiled executables are saved to and loaded from the
// specified file system directory path.
std::string tf_xla_persistent_cache_directory;
Expand All @@ -111,6 +120,11 @@ struct MarkForCompilationPassFlags {

// Specifies the persistance cache prefix. Default is "xla_compile_cache"
string tf_xla_persistent_cache_prefix;

// Sets the threshold for marking a cluster megamorphic.
// Setting it to -1 disables marking clusters megamorphic.
// Setting it to 0 uses the default behaviour of TensorFlow.
int64_t tf_xla_threshold_for_megamorphic;
};

// Flags associated with XLA Sparse Core.
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/jit/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cc_library(
"//tensorflow/compiler/jit:tf_graph_to_hlo_compiler",
"//tensorflow/compiler/jit:tf_to_hlo_compiler",
"//tensorflow/compiler/jit:xla_compile_util",
"//tensorflow/compiler/jit:xla_batch_matcher",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"@com_google_absl//absl/base:core_headers",
Expand Down
Loading
Loading